MLP Mixer
Implementation of the paper MLP-Mixer: An all-MLP Architecture for Vision. Give us a star if you like this repo.
Author:
- Github: bangoc123
- Email: [email protected]
This library is part of our Papers-Videos-Code project, in which we implement state-of-the-art AI papers and publish all source code. Videos explaining these models will also be uploaded to the ProtonX YouTube channels.
[Note] You can use your own data to train this model.
I. Set up environment
- Make sure you have installed Miniconda. If not, see the setup document here.
- cd into mlp-mixer and run conda env create -f environment.yml to set up the environment.
- Activate the conda environment using the command conda activate mlp-mixer
II. Set up your dataset.
Create 2 folders, train and validation, in the data folder (which has already been created). Then copy your images into these folders, organized into subfolders named after their classes.
- The train folder is used for the training process.
- The validation folder is used to validate the training result after each epoch.
This library uses the image_dataset_from_directory API from TensorFlow 2.0 to load images. Make sure you understand how it works by reading its documentation.
Structure of these folders:
train/
...class_a/
......a_image_1.jpg
......a_image_2.jpg
...class_b/
......b_image_1.jpg
......b_image_2.jpg
...class_c/
......c_image_1.jpg
......c_image_2.jpg
validation/
...class_a/
......a_image_1.jpg
......a_image_2.jpg
...class_b/
......b_image_1.jpg
......b_image_2.jpg
...class_c/
......c_image_1.jpg
......c_image_2.jpg
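Before training, the layout above can be sanity-checked with a short script. This is a minimal sketch using only the Python standard library. The helper names summarize_dataset and check_splits and the set of accepted image extensions are assumptions for illustration, not part of this repository. Class folders are listed in sorted order, which mirrors the alphanumerical order image_dataset_from_directory uses when class names are not given explicitly.

```python
from pathlib import Path

# Assumed set of extensions; adjust to the image formats you actually use.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def summarize_dataset(root):
    """Return {class_name: image_count} for each class subfolder of root."""
    root = Path(root)
    summary = {}
    # Sorted so the class order is deterministic.
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        count = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
        summary[class_dir.name] = count
    return summary


def check_splits(train_dir, valid_dir):
    """Raise ValueError if the splits disagree on classes or a class has no images."""
    train = summarize_dataset(train_dir)
    valid = summarize_dataset(valid_dir)
    if list(train) != list(valid):
        raise ValueError(f"class folders differ: {list(train)} vs {list(valid)}")
    empty = [name for split in (train, valid) for name, n in split.items() if n == 0]
    if empty:
        raise ValueError(f"class folders without images: {sorted(set(empty))}")
    return train, valid
```

For example, check_splits("data/train", "data/validation") returns the per-class image counts for both folders, or raises an error describing what is missing.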
III. Train your model by running this command line
python train.py --epochs ${epochs} --num-classes ${num_classes}
Example: to train a model for 10 epochs on a binary classification problem (2 classes):
python train.py --epochs 10 --num-classes 2
There are some important arguments for the script that you should consider when running it:
- train-folder: The folder of training images
- valid-folder: The folder of validation images
- model-folder: Where the model is saved after training
- num-classes: The number of classes in your problem
- batch-size: The batch size of the dataset
- c: Patch projection dimension
- dc: Token-mixing units, as described on page 3 of the paper
- ds: Channel-mixing units, as described on page 3 of the paper
- num-of-mlp-blocks: The number of MLP blocks
- learning-rate: The learning rate of the Adam optimizer
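To make these flags concrete, here is a minimal argparse sketch of how such a command-line interface could be declared. It is not copied from train.py: the flag names follow this README, but the build_parser helper and every default value are assumptions chosen for illustration, so check train.py for the real definitions.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags.

    All default values below are illustrative assumptions, not the
    defaults used by this repository's train.py.
    """
    parser = argparse.ArgumentParser(description="Train an MLP-Mixer classifier (sketch)")
    parser.add_argument("--train-folder", default="data/train", help="Folder of training images")
    parser.add_argument("--valid-folder", default="data/validation", help="Folder of validation images")
    parser.add_argument("--model-folder", default="output", help="Where the trained model is saved")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--num-classes", type=int, default=2, help="Number of classes")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size of the dataset")
    parser.add_argument("--c", type=int, default=256, help="Patch projection dimension")
    parser.add_argument("--dc", type=int, default=512, help="Token-mixing units")
    parser.add_argument("--ds", type=int, default=512, help="Channel-mixing units")
    parser.add_argument("--num-of-mlp-blocks", type=int, default=8, help="Number of MLP blocks")
    parser.add_argument("--learning-rate", type=float, default=1e-3, help="Adam learning rate")
    return parser
```

Parsing ["--epochs", "10", "--num-classes", "2"] with this parser gives args.epochs == 10 and args.num_classes == 2. Note that argparse converts hyphens in flag names to underscores in attribute names, so --num-of-mlp-blocks is read as args.num_of_mlp_blocks.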
After training completes successfully, your model will be saved to the model-folder defined above.
IV. Testing model with a new image
We provide a script for testing a model on a new image from the command line:
python predict.py --test-file-path ${test_file_path}
where test_file_path is the path of your test image.
Example:
python predict.py --test-file-path ./data/test/cat.2000.jpg
V. Feedback
If you encounter any issues when using this library, please let us know via the issues submission tab.