Swin Transformer (TensorFlow)
TensorFlow reimplementation of the Swin Transformer model.
Based on the official PyTorch implementation.
Requirements
tensorflow >= 2.4.1
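The requirement above is a minimum TensorFlow version. As a minimal sketch (not part of this repository; the helper name `meets_minimum` is hypothetical), the installed version string could be compared against 2.4.1 before building a model:

```python
import re

MIN_TF_VERSION = (2, 4, 1)  # minimum version from the Requirements section


def meets_minimum(version: str, minimum=MIN_TF_VERSION) -> bool:
    """Hypothetical helper: return True if `version` is >= `minimum`.

    Only the leading MAJOR.MINOR[.PATCH] numbers are compared; suffixes such
    as '-rc0' are ignored, so a pre-release counts as its final release.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing PATCH component (e.g. '2.10') is treated as 0
    parts = tuple(int(p) if p is not None else 0 for p in match.groups())
    return parts >= minimum
```

In practice you would call `meets_minimum(tf.__version__)` after `import tensorflow as tf`.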
Pretrained Swin Transformer Checkpoints
ImageNet-1K and ImageNet-22K Pretrained Checkpoints
name | pretrain | resolution | acc@1 | #params | model |
---|---|---|---|---|---|
swin_tiny_224 | ImageNet-1K | 224x224 | 81.2 | 28M | github |
swin_small_224 | ImageNet-1K | 224x224 | 83.2 | 50M | github |
swin_base_224 | ImageNet-22K | 224x224 | 85.2 | 88M | github |
swin_base_384 | ImageNet-22K | 384x384 | 86.4 | 88M | github |
swin_large_224 | ImageNet-22K | 224x224 | 86.3 | 197M | github |
swin_large_384 | ImageNet-22K | 384x384 | 87.3 | 197M | github |
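In the table, each checkpoint name ends with its input resolution, and the examples below use an `IMAGE_SIZE` for the input layer. A minimal sketch, assuming every name follows the `swin_<variant>_<resolution>` pattern shown in the table (the helper `input_size_from_name` is hypothetical, not part of this repository), derives the input size from the name so the two presumably stay in sync:

```python
def input_size_from_name(model_name: str) -> tuple:
    """Hypothetical helper: read the input resolution from a checkpoint name.

    Assumes the 'swin_<variant>_<resolution>' naming used in the table,
    e.g. 'swin_base_384' -> (384, 384).
    """
    # Unpacking raises ValueError if the name does not have exactly three parts
    prefix, variant, resolution = model_name.split("_")
    if prefix != "swin" or not resolution.isdigit():
        raise ValueError(f"unexpected model name: {model_name!r}")
    size = int(resolution)
    return (size, size)
```

For example, `IMAGE_SIZE = list(input_size_from_name('swin_tiny_224'))` gives `[224, 224]`.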
Examples
Initializing the model:
```python
from swintransformer import SwinTransformer
model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=True, pretrained=False)
```
You can use a pretrained model like this:
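```python
import tensorflow as tf
from swintransformer import SwinTransformer

IMAGE_SIZE = [224, 224]  # should match the checkpoint resolution (224 for 'swin_tiny_224')
NUM_CLASSES = 10         # placeholder: number of classes in your dataset

model = tf.keras.Sequential([
    # Scale pixels to [0, 1], then normalize with ImageNet mean/std ("torch" mode)
    tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
    SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),  # pretrained backbone, no head
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')  # new classification head
])
```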
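The `Lambda` layer in this example calls Keras's `preprocess_input` with `mode="torch"`, which Keras documents as scaling pixels to [0, 1] and then normalizing each channel with the ImageNet mean and standard deviation. The sketch below reproduces that arithmetic in plain Python for a single RGB pixel, only to make the expected value range easy to check; the function name `torch_mode_preprocess` is hypothetical, and inside a model you would keep using the `Lambda` layer.

```python
# ImageNet per-channel statistics (RGB order) used by Keras "torch" mode
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def torch_mode_preprocess(pixel):
    """Hypothetical helper: normalize one RGB pixel with values in [0, 255].

    Each channel is scaled to [0, 1], then the ImageNet mean is subtracted
    and the result is divided by the ImageNet standard deviation.
    """
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)
    )
```

A black pixel `(0, 0, 0)` maps to roughly `(-2.12, -2.04, -1.80)` and a white pixel `(255, 255, 255)` to roughly `(2.25, 2.43, 2.64)`, so the backbone sees inputs centred near zero rather than raw 0 to 255 values.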
If you use a pretrained model on a Kaggle TPU, also pass the `use_tpu=True` option:
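```python
import tensorflow as tf
from swintransformer import SwinTransformer

IMAGE_SIZE = [224, 224]  # should match the checkpoint resolution (224 for 'swin_tiny_224')
NUM_CLASSES = 10         # placeholder: number of classes in your dataset

model = tf.keras.Sequential([
    # Scale pixels to [0, 1], then normalize with ImageNet mean/std ("torch" mode)
    tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
    SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),  # TPU-compatible loading
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')  # new classification head
])
```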
Example: TPU training on Kaggle
Citation
```
@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
```