Simple-Image-Classification
Simple Image Classification Code (PyTorch)
This repository contains:
- Python3 / PyTorch code for multi-class image classification
Prerequisites
- See requirements.txt for details.
torch
torchvision
matplotlib
scikit-learn
tqdm # not mandatory but recommended
tensorboard # not mandatory but recommended
How to use
- Your dataset should have the following directory structure. (You can use our toy example: unzip cifar10_dummy.zip.)
|── 📂 your_own_dataset
    |── 📂 train
        |── 📂 class_1
            |── 🖼️ 1.jpg
            |── ...
        |── 📂 class_2
            |── 🖼️ ...
    |── 📂 valid
        |── 📂 class_1
        |── 📂 ...
    |── 📂 test
        |── 📂 class_1
        |── 📂 ...
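Before training, it can help to confirm that a dataset really follows this layout. The sketch below is a hypothetical helper (check_dataset_layout is not part of this repository) that walks train/valid/test, counts image files per class folder, and checks that every split uses the same class names. That last check matters because torchvision's ImageFolder assigns label indices from the sorted class-folder names. The accepted file suffixes are an assumption; adjust them to your data.

```python
from pathlib import Path
from typing import Dict, Sequence

# Assumption: these suffixes cover the image files in your dataset.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp'}


def check_dataset_layout(root: str,
                         splits: Sequence[str] = ('train', 'valid', 'test')) -> Dict[str, Dict[str, int]]:
    """Return {split: {class_name: image_count}} or raise if the layout looks wrong."""
    summary: Dict[str, Dict[str, int]] = {}
    for split in splits:
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f'missing split directory: {split_dir}')
        counts: Dict[str, int] = {}
        # One sub-folder per class; the folder name is the class name.
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
        if not counts:
            raise ValueError(f'no class folders found in {split_dir}')
        summary[split] = counts

    # Label indices come from sorted class names, so every split
    # should contain the same set of class folders.
    reference = set(summary[splits[0]])
    for split, counts in summary.items():
        if set(counts) != reference:
            raise ValueError(f'class folders in {split!r} differ from those in {splits[0]!r}')
    return summary


if __name__ == '__main__':
    print(check_dataset_layout('./cifar10_dummy'))
```

Raising early on a missing split or mismatched class set is cheaper than discovering shifted labels after a long training run.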
- Check __init__.py. You might need to modify variables or add components (transformations, optimizers, learning-rate schedulers, ...).
Tip: you can add your own loss function as follows:
...
def get_loss_function(loss_function_name, device):
    ...
    elif loss_function_name == 'your_own_function_name':  # add this branch
        return your_own_function()
    ...
...
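To make the dispatch pattern concrete, here is a framework-free sketch that mirrors the get_loss_function signature above. It is not the repository's code: it works on plain Python lists of logits for a single sample, and the label-smoothing loss is only a hypothetical stand-in for your_own_function. In the real PyTorch code, the branch would return a module instead; for example, torch.nn.CrossEntropyLoss accepts a label_smoothing argument in PyTorch 1.10 and later.

```python
import math
from typing import Callable, List, Optional


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def cross_entropy(logits: List[float], target: int) -> float:
    # Softmax cross-entropy for one sample: -log p(target).
    return -log_softmax(logits)[target]


class LabelSmoothingCE:
    """Hypothetical 'your_own_function': cross-entropy with label smoothing."""

    def __init__(self, smoothing: float = 0.1):
        self.smoothing = smoothing

    def __call__(self, logits: List[float], target: int) -> float:
        logp = log_softmax(logits)
        # Target distribution: (1 - eps) on the true class plus eps / K on every class.
        nll = -logp[target]
        uniform = -sum(logp) / len(logits)
        return (1 - self.smoothing) * nll + self.smoothing * uniform


def get_loss_function(loss_function_name: str, device: Optional[str] = None) -> Callable:
    # `device` only mirrors the repository's signature; it is unused in this sketch.
    if loss_function_name == 'CE':
        return cross_entropy
    elif loss_function_name == 'your_own_function_name':
        return LabelSmoothingCE(smoothing=0.1)
    raise ValueError(f'unknown loss function: {loss_function_name}')
```

Raising on an unknown name, rather than silently falling back to a default, surfaces typos in --loss_function immediately.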
- Run train.py for training. An example is shown below; see src/my_utils/parser.py for details.
Tip: --loss_function='CE' selects softmax cross-entropy, which is the default loss.
python train.py --network_name='resnet34_for_tiny' --dataset_dir='./cifar10_dummy' \
--batch_size=256 --epochs=5 \
--lr=0.1 --lr_step='[60, 120, 160]' --lr_step_gamma=0.5 --lr_warmup_epochs=5 \
--auto_mean_std --store_weights --store_loss_acc_log --store_logits --store_confusion_matrix \
--loss_function='your_own_function_name' --transform_list_name='CIFAR' --tag='train-001'
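The learning-rate flags above combine a warmup phase with step decay. The sketch below shows one common way such a schedule is computed; it is an assumption, not the repository's implementation. It assumes a linear warmup over --lr_warmup_epochs that reaches --lr on the last warmup epoch, followed by MultiStepLR-style decay that multiplies by --lr_step_gamma at each milestone in --lr_step. The helper names parse_lr_steps and lr_at_epoch are hypothetical.

```python
import ast
from bisect import bisect_right
from typing import List


def parse_lr_steps(text: str) -> List[int]:
    # --lr_step arrives as a string such as '[60, 120, 160]'; literal_eval parses it without eval().
    return sorted(int(step) for step in ast.literal_eval(text))


def lr_at_epoch(epoch: int, base_lr: float, milestones: List[int],
                gamma: float, warmup_epochs: int) -> float:
    """Learning rate for a 0-indexed epoch: linear warmup, then step decay."""
    if epoch < warmup_epochs:
        # Ramp up linearly so that the last warmup epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Multiply by gamma once for every milestone already reached (MultiStepLR-style).
    return base_lr * gamma ** bisect_right(milestones, epoch)


if __name__ == '__main__':
    steps = parse_lr_steps('[60, 120, 160]')
    for epoch in (0, 4, 59, 60, 120, 160):
        print(epoch, lr_at_epoch(epoch, 0.1, steps, 0.5, 5))
```

Under this assumption, the five-epoch example command would spend its entire run in warmup, and the milestones would only take effect in runs longer than 60 epochs.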
- Run test.py for testing. An example is shown below; see src/my_utils/parser.py for details.
python test.py --network_name='resnet34_for_tiny' --dataset_dir='./cifar10_dummy' \
--auto_mean_std --store_logits --store_confusion_matrix \
--checkpoint='pretrained_model_weights.pt'
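The --store_logits and --store_confusion_matrix flags relate stored outputs to a confusion matrix. As a rough illustration (not the repository's code, and the function names are hypothetical), the sketch below builds a confusion matrix from per-sample logits by taking the argmax as the prediction, then reads accuracy off the diagonal. It uses the convention of rows as true classes and columns as predicted classes, which matches sklearn.metrics.confusion_matrix from scikit-learn, already listed in the requirements.

```python
from typing import List, Sequence


def confusion_matrix_from_logits(logits: Sequence[Sequence[float]],
                                 labels: Sequence[int],
                                 num_classes: int) -> List[List[int]]:
    """Rows are true classes and columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for scores, label in zip(logits, labels):
        # The predicted class is the index of the largest logit (first one on ties).
        pred = max(range(num_classes), key=scores.__getitem__)
        matrix[label][pred] += 1
    return matrix


def accuracy_from_matrix(matrix: List[List[int]]) -> float:
    # Correct predictions sit on the diagonal.
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total if total else 0.0


if __name__ == '__main__':
    m = confusion_matrix_from_logits([[2.0, 1.0], [0.5, 3.0], [4.0, 0.0]], [0, 1, 1], 2)
    print(m, accuracy_from_matrix(m))
```

Softmax is monotonic, so taking the argmax of raw logits gives the same prediction as taking it over softmax probabilities.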
Trailer
- If you install tqdm, you can monitor training progress with a progress bar.
- If you install tensorboard, you can track accuracy/loss curves and confusion matrices during training.
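Because tqdm and tensorboard are optional, code that uses them needs a fallback when they are missing. The sketch below shows one common pattern; it is an assumption about how such optional imports could be handled, not a description of this repository. The helper names get_progress_wrapper and get_summary_writer are hypothetical. SummaryWriter is imported from torch.utils.tensorboard, which requires the tensorboard package.

```python
from typing import Any, Callable, Iterable, Optional


def get_progress_wrapper() -> Callable[..., Iterable]:
    """Return tqdm.tqdm when it is installed, otherwise a silent pass-through."""
    try:
        from tqdm import tqdm
        return tqdm
    except ImportError:
        def passthrough(iterable: Iterable, **kwargs: Any) -> Iterable:
            # Accepts tqdm-style keyword arguments (desc=..., total=...) and ignores them.
            return iterable
        return passthrough


def get_summary_writer(log_dir: str) -> Optional[Any]:
    """Return a TensorBoard SummaryWriter when torch and tensorboard are installed, else None."""
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        return None
    return SummaryWriter(log_dir=log_dir)


if __name__ == '__main__':
    progress = get_progress_wrapper()
    for batch in progress(range(3), desc='train'):
        pass
```

When a writer is available, SummaryWriter.add_scalar can log loss and accuracy per epoch, and SummaryWriter.add_figure can log a matplotlib rendering of a confusion matrix. Guarding each call with `if writer is not None` keeps training working without tensorboard.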
Contribution