Code for the Convolutional Vision Transformer (ConViT)

Overview

ConViT: Vision Transformers with Soft Convolutional Inductive Biases

This repository contains PyTorch code for ConViT. It builds on code from DeiT (Data-Efficient Image Transformers) and from timm.

For details see the ConViT paper by Stéphane d'Ascoli, Hugo Touvron, Matthew Leavitt, Ari Morcos, Giulio Biroli and Levent Sagun.

If you use this code for a paper please cite:

@article{d2021convit,
  title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
  author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
  journal={arXiv preprint arXiv:2103.10697},
  year={2021}
}

Usage

Install PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2:

conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
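
Optionally, check that the expected versions are the ones actually installed (a quick sanity check, not part of the repository):

python -c "import torch, torchvision, timm; print(torch.__version__, torchvision.__version__, timm.__version__)"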

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout expected by torchvision's datasets.ImageFolder; the training and validation data are expected to be in the train/ and val/ folders, respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
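
As a quick, optional check of the layout (a sketch, not part of the training code), each split can be opened directly with torchvision's ImageFolder:

from torchvision.datasets import ImageFolder

train_set = ImageFolder('/path/to/imagenet/train')
val_set = ImageFolder('/path/to/imagenet/val')
print(len(train_set.classes), len(train_set), len(val_set))  # e.g. 1000 classes for ImageNet-1k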

Evaluation

To evaluate ConViT-Ti on the ImageNet validation set, run:

python main.py --eval --model convit_tiny --pretrained --data-path /path/to/imagenet

This should give:

Acc@1 73.116 Acc@5 91.710 loss 1.172

Training

To train ConViT-Ti on ImageNet on a single node with 4 GPUs for 300 epochs, run:

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model convit_tiny --batch-size 256 --data-path /path/to/imagenet

To train the same model on a subsampled version of ImageNet where we only use 10% of the images of each class, add --sampling_ratio 0.1, for example:
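
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model convit_tiny --batch-size 256 --sampling_ratio 0.1 --data-path /path/to/imagenet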

Multinode training

Distributed training is available via Slurm and submitit:

pip install submitit

To train ConViT-Base on ImageNet on 2 nodes with 8 GPUs each for 300 epochs, run:

python run_with_submitit.py --model convit_base --data-path /path/to/imagenet

License

The majority of this repository is released under the CC-BY-NC 4.0 license, as found in the LICENSE file; however, portions of the project are available under separate license terms: DeiT and timm are licensed under Apache 2.0.

Comments
  • How to train on custom data using Colab?

    I am getting this error:

        Namespace(aa='rand-m9-mstd0.5-inc1', batch_size=64, clip_grad=None, color_jitter=0.4, cooldown_epochs=10, cutmix=1.0, cutmix_minmax=None, data_path='/content/stonks', data_set='IMNET', decay_epochs=30, decay_rate=0.1, device='cuda', dist_url='env://', distributed=False, drop=0.0, drop_block=None, drop_path=0.1, embed_dim=48, epochs=300, eval=False, inat_category='name', input_size=224, local_up_to_layer=10, locality_strength=1.0, lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, min_lr=1e-05, mixup=0.8, mixup_mode='batch', mixup_prob=1.0, mixup_switch_prob=0.5, model='convit_tiny', model_ema=False, model_ema_decay=0.99996, model_ema_force_cpu=False, momentum=0.9, nb_classes=None, num_workers=10, opt='adamw', opt_betas=None, opt_eps=1e-08, output_dir='', patience_epochs=10, pin_mem=True, pretrained=False, recount=1, remode='pixel', repeated_aug=True, reprob=0.25, resplit=False, resume='', sampling_ratio=1.0, save_every=None, sched='cosine', seed=0, smoothing=0.1, start_epoch=0, train_interpolation='bicubic', warmup_epochs=5, warmup_lr=1e-06, weight_decay=0.05, world_size=1)
        Traceback (most recent call last):
          File "/content/convit/main.py", line 383, in <module>
            main(args)
          File "/content/convit/main.py", line 194, in main
            dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
          File "/content/convit/datasets.py", line 148, in build_dataset
            sampling_ratio= (args.sampling_ratio if is_train else 1.), nb_classes=args.nb_classes)
          File "/content/convit/datasets.py", line 130, in __init__
            is_valid_file=is_valid_file, **kwargs)
          File "/content/convit/datasets.py", line 107, in __init__
            classes, class_to_idx = self._find_classes(self.root)
        AttributeError: 'ImageNetDataset' object has no attribute '_find_classes'
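
    For reference, this AttributeError typically appears with newer torchvision releases, which renamed ImageFolder's private _find_classes() to the public find_classes(); the repository's ImageNetDataset still calls the old name. Besides installing the torchvision version pinned in the README, a hedged workaround (assuming ImageNetDataset ultimately subclasses torchvision's DatasetFolder/ImageFolder) is a small compatibility shim added before the dataset is built:

        from torchvision.datasets import DatasetFolder

        if not hasattr(DatasetFolder, '_find_classes'):
            def _find_classes(self, directory):
                # delegate to the new public API, which returns the same (classes, class_to_idx) pair
                return self.find_classes(directory)
            DatasetFolder._find_classes = _find_classes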

    opened by Tylersuard 4
  • Adding a pre save stage

    Use case:

    This modification allows a pre-training save of the model state dictionary, so that a 0th (initial) state can be saved when undertaking meta-heuristic learning or hyperparameter search.

    There may be other use cases where a base checkpoint is necessary,

    such as visualizing accuracy or loss (y-axis) against epoch n (x-axis), where the pre-training ConViT state is useful for evaluation.

    To follow as part of my own research:

    Grid search of hyperparameters with the convit code running.

    CLA Signed 
    opened by fdsig 3
  • Import error when running in colab

    Hello, I'm trying to run the evaluation and training commands on Colab:

        !python /content/convit/main.py --eval --model convit_tiny --pretrained --data-path /content/drive/MyDrive/convitry/

    and:

        !python -m torch.distributed.launch --use_env /content/convit/main.py --model convit_tiny --batch-size 256 --data-path /content/drive/MyDrive/convitry/

    and I keep getting this error:

    ImportError: cannot import name 'container_abcs' from 'torch._six' (/usr/local/lib/python3.7/dist-packages/torch/_six.py)

    Can you help? Also any notes on running the code on colab?
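
    For reference, this ImportError comes from timm 0.3.2, whose timm/models/layers/helpers.py imports container_abcs from torch._six; newer PyTorch releases (such as the one preinstalled on Colab) removed that symbol. Installing the PyTorch/torchvision versions pinned in the README avoids it; a common hedged workaround is to edit that import so it falls back to the standard library:

        try:
            from torch._six import container_abcs       # PyTorch <= 1.8
        except ImportError:
            import collections.abc as container_abcs    # newer PyTorch removed torch._six.container_abcs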

    opened by NouranFadlallah 2
  • Multi-GPUs training

    Hello, I have some questions about multi-GPU training. I don't know how to add parameters when I launch with:

        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main.py

    I want to add: kinetics RGB --arch tea50 --num_segments 8 --gpus 0 1 2 3 4 5 6 7 --gd 20 --lr 0.0009 --lr_steps 20 40 --epochs 200 --batch-size 16 -j 16 --dropout 0.5 --consensus_type=avg --eval-freq=1 --print-freq=200 --experiment_name=TEA --shift --shift_div=8 --resume 0 --shift_place=blockres

    Can you help me? I'm overwhelmed with gratitude.
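
    For reference, arguments intended for the training script simply go after main.py in the launch command. Note, though, that the flags listed above (--arch, --num_segments, --consensus_type, ...) appear to come from a different (video) codebase and are not accepted by this repository's main.py, which only understands the options defined in its own argument parser. A hedged example using flags that main.py does define:

        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model convit_tiny --batch-size 64 --epochs 200 --lr 0.0009 --data-path /path/to/imagenet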

    opened by Administor123 2
  • Question on equation (5)

    Hi,

    Thanks for your great work and repo! I just have a quick question on equation (5) of your paper.

    In your paper and code: [image]

    It seems you have forced the value projection matrix to be an identity matrix at all times to make the quadratic encoding work. However, in the paper On the Relationship between Self-Attention and Convolutional Layers, the quadratic encoding equation is defined differently [image]: only the \hat{key} projection matrix that computes the relative position embedding [image] is kept as an identity matrix, which differs from your implementation, since you keep both the \hat{key} and the value matrices as identity matrices.

    Can you kindly let me know whether I have misunderstood anything?

    Kind regards, Haoyu
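
    For reference, the quadratic positional encoding from Cordonnier et al. that this discussion refers to can be written (up to notation, and hedged against the exact indexing used in either paper) as

        A^h_{ij} \propto \mathrm{softmax}_j\!\left( v_{\mathrm{pos}}^{h\,\top} r_{ij} \right),
        \qquad v_{\mathrm{pos}}^{h} = -\alpha^h \,\big(1,\; -2\Delta^h_1,\; -2\Delta^h_2\big)^{\!\top},
        \qquad r_{ij} = \big(\lVert \delta_{ij} \rVert^2,\; \delta_{ij,1},\; \delta_{ij,2}\big)^{\!\top},

    so that v_{\mathrm{pos}}^{h\,\top} r_{ij} = -\alpha^h \big(\lVert \delta_{ij} \rVert^2 - 2\,\Delta^h\!\cdot\delta_{ij}\big) = -\alpha^h \lVert \delta_{ij} - \Delta^h \rVert^2 + \mathrm{const}: each head attends to a bump centred at the offset \Delta^h, with locality controlled by \alpha^h. The question above is only about which of the learned projections (the \hat{key} projection vs. the value projection) are frozen to the identity around this term.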

    opened by Charleshhy 2
  • The weights won't load

    Unexpected key(s) in state_dict: "blocks.3.attn.gating_param", "blocks.3.attn.qk.weight", "blocks.3.attn.v.weight", "blocks.3.attn.pos_proj.weight", "blocks.3.attn.pos_proj.bias", "blocks.4.attn.gating_param", "blocks.4.attn.qk.weight", "blocks.4.attn.v.weight", "blocks.4.attn.pos_proj.weight", "blocks.4.attn.pos_proj.bias", "blocks.5.attn.gating_param", "blocks.5.attn.qk.weight", "blocks.5.attn.v.weight", "blocks.5.attn.pos_proj.weight", "blocks.5.attn.pos_proj.bias", "blocks.6.attn.gating_param", "blocks.6.attn.qk.weight", "blocks.6.attn.v.weight", "blocks.6.attn.pos_proj.weight", "blocks.6.attn.pos_proj.bias", "blocks.7.attn.gating_param", "blocks.7.attn.qk.weight", "blocks.7.attn.v.weight", "blocks.7.attn.pos_proj.weight", "blocks.7.attn.pos_proj.bias", "blocks.8.attn.gating_param", "blocks.8.attn.qk.weight", "blocks.8.attn.v.weight", "blocks.8.attn.pos_proj.weight", "blocks.8.attn.pos_proj.bias", "blocks.9.attn.gating_param", "blocks.9.attn.qk.weight", "blocks.9.attn.v.weight", "blocks.9.attn.pos_proj.weight", "blocks.9.attn.pos_proj.bias".
            size mismatch for cls_token: copying a param with shape torch.Size([1, 1, 192]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
            size mismatch for pos_embed: copying a param with shape torch.Size([1, 196, 192]) from checkpoint, the shape in current model is torch.Size([1, 196, 768]).
            size mismatch for patch_embed.proj.weight: copying a param with shape torch.Size([192, 3, 16, 16]) from checkpoint, the shape in current model is torch.Size([768, 3, 16, 16]).
            size mismatch for patch_embed.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.0.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.0.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.0.attn.qk.weight: copying a param with shape torch.Size([384, 192]) from checkpoint, the shape in current model is torch.Size([1536, 768]).
            size mismatch for blocks.0.attn.v.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.0.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.0.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.0.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.0.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.0.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.1.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.1.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.1.attn.qk.weight: copying a param with shape torch.Size([384, 192]) from checkpoint, the shape in current model is torch.Size([1536, 768]).
            size mismatch for blocks.1.attn.v.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.1.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.1.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.1.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.1.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.1.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.2.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.2.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.2.attn.qk.weight: copying a param with shape torch.Size([384, 192]) from checkpoint, the shape in current model is torch.Size([1536, 768]).
            size mismatch for blocks.2.attn.v.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.2.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.2.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.2.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.2.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.2.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.2.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.2.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.2.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.3.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.3.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.3.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.3.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.3.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.3.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.3.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.3.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.3.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.3.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.4.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.4.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.4.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.4.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.4.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.4.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.4.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.4.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.4.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.4.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.5.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.5.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.5.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.5.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.5.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.5.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.5.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.5.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.5.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.5.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.6.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.6.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.6.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.6.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.6.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.6.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.6.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.6.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.6.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.6.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.7.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.7.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.7.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.7.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.7.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.7.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.7.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.7.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.7.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.7.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.8.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.8.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.8.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.8.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.8.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.8.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.8.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.8.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.8.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.8.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.9.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.9.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.9.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.9.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.9.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.9.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.9.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.9.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.9.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.9.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.10.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.10.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.10.attn.qkv.weight: copying a param with shape torch.Size([576, 192]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
            size mismatch for blocks.10.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.10.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.10.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.10.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.10.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.10.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.10.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.10.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.11.norm1.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.11.norm1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.11.attn.qkv.weight: copying a param with shape torch.Size([576, 192]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
            size mismatch for blocks.11.attn.proj.weight: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([768, 768]).
            size mismatch for blocks.11.attn.proj.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.11.norm2.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.11.norm2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for blocks.11.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
            size mismatch for blocks.11.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([3072]).
            size mismatch for blocks.11.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
            size mismatch for blocks.11.mlp.fc2.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for norm.weight: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for norm.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([768]).
            size mismatch for head.weight: copying a param with shape torch.Size([1000, 192]) from checkpoint, the shape in current model is torch.Size([1000, 768]).
    

    May I ask what the problem is, and how do I solve it?

    opened by lmk123568 2
  • Multilabel classification for training

    @likethesky @sdascoli I want to train ConViT for multi-label classification, where the input image size is 256x128 and the labels are vectors. Can I do so, and if so, what changes do I have to make in the training pipeline? Thanks in advance.

    opened by abhigoku10 1
  • Size mismatch on testing the model with pretrained weights.

    Hi, thanks for sharing the code! While I was testing it out with pretrained models on ImageNet21k, I encountered an error:

    RuntimeError: Error(s) in loading state_dict for VisionTransformer:
            size mismatch for cls_token: copying a param with shape torch.Size([1, 1, 192]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
            size mismatch for pos_embed: copying a param with shape torch.Size([1, 196, 192]) from checkpoint, the shape in current model is torch.Size([1, 196, 768]).
            size mismatch for patch_embed.proj.weight: copying a param with shape torch.Size([192, 3, 16, 16]) from checkpoint, the shape in current model is torch.Size([768, 3, 16, 16]).
    .... (more size mismatch messages...)
    

    Looks like the same error encountered by issue #4 (which was addressed by PR #6). But I'm facing the error on all the pretrained models (tiny, small, base). I think I might have found the issue:

    On convit.py#L305: embed_dim *= num_heads was added as part of #6.

    But on models.py#20: kwargs['embed_dim'] *= num_heads already multiplies the embed_dim by num_heads (the same operation for all models).

    So it looks to me that convit.py#L305 is actually redundant. It's also causing the size error above. When I remove that line, the test works as expected.
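
    For a concrete picture of the doubled scaling (a hedged sketch; the variable names only mirror the two code paths mentioned above and are not the repository's code):

        num_heads = 4                                  # convit_tiny uses 4 heads
        embed_dim_per_head = 48                        # the --embed_dim default in main.py

        embed_dim = embed_dim_per_head * num_heads     # 192 -- what the convit_tiny checkpoint expects
        doubled = embed_dim * num_heads                # 768 -- the shape reported in the size-mismatch errors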

    opened by jdubpark 1
  • Multi-Gpu training

    Hi! Thank you for sharing the code. I have some questions about how to train with multiple GPUs. When I used the original code, it only ran on a single GPU (I found that 'DistributedDataParallel' doesn't work). After that, I wrote this:

    # if args.distributed:
    #     model = torch.nn.parallel.DistributedDataParallel(model, device_ids='0,1')
    #     model_without_ddp = model.module
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model = DataParallel(model)
    

    But I get this message: RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #3 'mat2'; but device 0 does not equal 1 (while checking arguments for addmm)

    Could you help me to solve this problem? Also, what is the meaning of '--world_size' which default is setting to be '1'?

    Thanks!
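
    For reference, a hedged sketch of the usual DeiT-style multi-GPU setup (it assumes the script is started with torch.distributed.launch, which spawns one process per GPU and initializes the process group from environment variables; model and args come from main.py):

        import torch

        torch.cuda.set_device(args.gpu)                 # args.gpu = local rank of this process
        model.to(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])               # one device per process, not a string like '0,1'

    Under that launcher, --world_size is the total number of distributed processes (i.e. GPUs across all nodes); the default of 1 just means single-process training until the launcher's environment overrides it.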

    opened by SZUHvern 1
  •  CC-BY-NC 4.0 License

    Dear researchers, hope you're well, and thanks for your wonderful work on ConViT. I quite enjoyed reading the paper and wasn't aware that self-attention could also express a convolutional layer [1].

    Would love to get this model as part of TIMM, but as you can see in the discussion here, Ross tells me we need to wait for an unrestrictive license.

    Just wondering if that's something in the works? Is there some time in the near future that we're going to move towards an Apache license? Why do we have a non-commercial license here, please?

    opened by amaarora 1
  • Performance on ImageNet is lower than reported.

    Dear contributors,

    Thanks for releasing your code. We used your codebase to run ConViT-Ti on the ImageNet dataset and achieved 72.5% Top-1 Accuracy, which is 0.6% lower than you reported. Could you please let us know how to reproduce your result? Here is our setting:

    8 V100 GPUs, nproc_per_node=8, batch-size=128. The other settings are the same as in your main.py file. Here is the argument parser for your reference:

    def get_args_parser():
     parser = argparse.ArgumentParser('ConViT training and evaluation script', add_help=False)
     parser.add_argument('--batch-size', default=128, type=int)
     parser.add_argument('--epochs', default=300, type=int)
    
     # Model parameters
     parser.add_argument('--model', default='convit_small', type=str, metavar='MODEL',
                         help='Name of model to train')
     parser.add_argument('--pretrained', action='store_true')
    
     parser.add_argument('--input-size', default=224, type=int, help='images input size')
     parser.add_argument('--embed_dim', default=48, type=int, help='embedding dimension per head')
    
     parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                         help='Dropout rate (default: 0.)')
     parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                         help='Drop path rate (default: 0.1)')
     parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                         help='Drop block rate (default: None)')
    
     parser.add_argument('--model-ema', action='store_true')
     parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
     parser.set_defaults(model_ema=False)
     parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
     parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
    
     # Optimizer parameters
     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                         help='Optimizer (default: "adamw"')
     parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                         help='Optimizer Epsilon (default: 1e-8)')
     parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                         help='Optimizer Betas (default: None, use opt default)')
     parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                         help='Clip gradient norm (default: None, no clipping)')
     parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                         help='SGD momentum (default: 0.9)')
     parser.add_argument('--weight-decay', type=float, default=0.05,
                         help='weight decay (default: 0.05)')
     # Learning rate schedule parameters
     parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                         help='LR scheduler (default: "cosine"')
     parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                         help='learning rate (default: 5e-4)')
     parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                         help='learning rate noise on/off epoch percentages')
     parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                         help='learning rate noise limit percent (default: 0.67)')
     parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                         help='learning rate noise std-dev (default: 1.0)')
     parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                         help='warmup learning rate (default: 1e-6)')
     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                         help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    
     parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                         help='epoch interval to decay LR')
     parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                         help='epochs to warmup LR, if scheduler supports')
     parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                         help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
     parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                         help='patience epochs for Plateau LR scheduler (default: 10')
     parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                         help='LR decay rate (default: 0.1)')
    
     # Augmentation parameters
     parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                         help='Color jitter factor (default: 0.4)')
     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                         help='Use AutoAugment policy. "v0" or "original". " + \
                              "(default: rand-m9-mstd0.5-inc1)'),
     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
     parser.add_argument('--train-interpolation', type=str, default='bicubic',
                         help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
    
     parser.add_argument('--repeated-aug', action='store_true')
     parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
     parser.set_defaults(repeated_aug=True)
    
     # * Random Erase params
     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                         help='Random erase prob (default: 0.25)')
     parser.add_argument('--remode', type=str, default='pixel',
                         help='Random erase mode (default: "pixel")')
     parser.add_argument('--recount', type=int, default=1,
                         help='Random erase count (default: 1)')
     parser.add_argument('--resplit', action='store_true', default=False,
                         help='Do not random erase first (clean) augmentation split')
    
     # * Mixup params
     parser.add_argument('--mixup', type=float, default=0.8,
                         help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
     parser.add_argument('--cutmix', type=float, default=1.0,
                         help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
     parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                         help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
     parser.add_argument('--mixup-prob', type=float, default=1.0,
                         help='Probability of performing mixup or cutmix when either/both is enabled')
     parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                         help='Probability of switching to cutmix when both mixup and cutmix enabled')
     parser.add_argument('--mixup-mode', type=str, default='batch',
                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
    
     # Dataset parameters
     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                         help='dataset path')
     parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19'],
                         type=str, help='Image Net dataset path')
     parser.add_argument('--sampling_ratio', default=1.,
                         type=float, help='fraction of samples to keep in the training set of imagenet')
     parser.add_argument('--nb_classes', default=None,
                         type=int, help='number of classes in imagenet')
     parser.add_argument('--inat-category', default='name',
                         choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                         type=str, help='semantic granularity')
    
     parser.add_argument('--output_dir', default='',
                         help='path where to save, empty for no saving')
     parser.add_argument('--eval_freq', default=10, type=int)
     parser.add_argument('--device', default='cuda',
                         help='device to use for training / testing')
     parser.add_argument('--seed', default=0, type=int)
     parser.add_argument('--resume', default='', help='resume from checkpoint')
     parser.add_argument('--save_every', default=None, type=int, help='save model every epochs')
     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                         help='start epoch')
     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
     parser.add_argument('--num_workers', default=8, type=int)
     parser.add_argument('--pin-mem', action='store_true',
                         help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
     parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                         help='')
     parser.set_defaults(pin_mem=True)
    
     # distributed training parameters
     parser.add_argument('--world_size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    
     # locality parameters
     parser.add_argument('--local_up_to_layer', default=10, type=int,
                         help='number of GPSA layers')
     parser.add_argument('--locality_strength', default=1., type=float,
                         help='Determines how focused each head is around its attention center')
    
     return parser
    Many thanks!
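
    One thing worth double-checking (a hedged note, assuming --batch-size is the per-process batch size, as is usual for torch.distributed.launch scripts): the effective batch size of this run matches the README command, so the gap is unlikely to come from that alone.

        readme_run = 4 * 256          # README command: 4 GPUs x 256 per GPU = 1024
        this_run = 8 * 128            # setup above: 8 GPUs x 128 per GPU = 1024
        print(readme_run, this_run)   # 1024 1024
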
    question 
    opened by FreddieRao 1
  • About Nonlocality

    Thanks for your great work and codes.

    I am a little bit confused about your implementation of nonlocality (in main.py (L346-351)).

    Here is the code:

    batch = next(iter(data_loader_val))[0]
    batch = batch.to(device)
    batch = model_without_ddp.patch_embed(batch)
    for l in range(len(model_without_ddp.blocks)):
        attn =  model_without_ddp.blocks[l].attn
        nonlocality[l] = attn.get_attention_map(batch).detach().cpu().numpy().tolist()
    

    It seems that you always feed the original patch embeddings to all 12 blocks. Shouldn't the inputs of attn.get_attention_map be [original patch embeddings, outputs of the block 1, ..., outputs of the block 11]?

    If I understand it wrong, please correct me.

    Sincerely, looking forward to your reply.
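
    For reference, a sketch of the sequential variant the question describes (hypothetical, not the repository's code): measure each block's attention on the representation that block actually receives, then propagate the activations forward.

    x = model_without_ddp.patch_embed(batch)
    for l, blk in enumerate(model_without_ddp.blocks):
        # attention statistics computed on this block's actual input
        nonlocality[l] = blk.attn.get_attention_map(x).detach().cpu().numpy().tolist()
        x = blk(x)  # feed this block's output to the next block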

    opened by yangbang18 0
  • Visualize Attention Map

    Thank you for your work and for providing the code.

    How can we visualize the attention feature map? Any comments would be really helpful.

    Best wishes cenchaojun
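
    For what it's worth, a minimal plotting sketch (hedged: it assumes you can obtain an attention tensor attn_map of shape (batch, heads, num_patches, num_patches) for a given block, e.g. via the get_attention_map call used in the nonlocality snippet above or a forward hook; check convit.py for the exact signature):

        import matplotlib.pyplot as plt

        b, h, q = 0, 0, 98                       # sample, head and query-patch index (14x14 grid -> 196 patches)
        side = int(attn_map.shape[-1] ** 0.5)    # 14 for 224x224 inputs with 16x16 patches
        heat = attn_map[b, h, q].reshape(side, side).detach().cpu().numpy()

        plt.imshow(heat, cmap='viridis')
        plt.title('attention of query patch %d, head %d' % (q, h))
        plt.colorbar()
        plt.savefig('attention_map.png')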

    opened by cenchaojun 0
  • Using ConViT for object detection [Discussion]

    Thank you for your work and for providing the code.

    Based on my limited knowledge of the transformers used in detectors and trackers, I see a common trend of using a ResNet-50 backbone for feature extraction before the transformer layers, and the only paper I could find without a conv backbone is WB-DETR, which introduces a module similar to the T2T module and claims to capture "rich local information" from the patch. In general, the argument I have read is that transformers are not able to capture the local information from the patch well and hence miss small objects.

    Do you think ConViT would be able to handle object detection if used as the backbone in the DETR framework instead of ResNet-50 + Transformer? Intuitively, the GPSA from your paper should be able to capture convolutional features from the image as well as the SA properties, based on the loss function from DETR.

    Would be happy to try it out if you think it's possible. I have been trying to find an architecture that doesn't use conv layers and works on patch-based representations, and your work seemed very relevant.

    Any comments would be really helpful.

    Thank you. Saurabh

    opened by sfarkya04 1
  • Using custom data using IMNET option

    Hello, I'm trying to train the model on a custom medical dataset with two classes (normal and abnormal), and I'm using the IMNET option since it loads data from the specified directory. However, the model doesn't seem to train (accuracy is always around 55%). While debugging I noticed that the loaded images are totally black.

    When running this code on Colab:

    !python -m torch.distributed.launch --use_env /content/convit/main.py --epochs 50 --mixup 0.8 --model convit_base --drop 0.7 --batch-size 32 --nb_classes 2 --output_dir /content/drive/MyDrive/models/convit/ --data-path /content/drive/MyDrive/ct_data/RawConvit/

    I added this code (in main.py after defining data_loader_train) to check the generated images:

        import matplotlib.pyplot as plt
        train_features, train_labels = next(iter(data_loader_train))
        img = train_features[0].squeeze().permute(1, 2, 0)
        label = train_labels[0]
        plt.imshow(img, cmap="gray")
        plt.savefig('img.png')
        print(np.amax(img.numpy()))
        print(np.amin(img.numpy()))
     
    

    where img.png is a totally black image, and the print commands give -1.1072767 and -2.117904.
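
    For reference, negative values in that range are what you would expect after the ImageNet mean/std normalization applied by the training transforms, and plt.imshow clips negative values to black, so the tensors are not necessarily corrupted. A hedged sketch for judging the actual image content, un-normalizing with the standard ImageNet statistics first:

        # un-normalize before plotting; mean/std below are the standard ImageNet values
        import torch
        import matplotlib.pyplot as plt

        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        img = train_features[0] * std + mean          # train_features comes from the snippet above
        plt.imshow(img.permute(1, 2, 0).clamp(0, 1))
        plt.savefig('img_unnormalized.png')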

    My dataset has the following structure, and the code reads the number of classes / images correctly:

    /RawConvit/  
      train/  
        abnormal/  
          img_1.jpg 
        normal/
          img_2.jpg
      val/
        abnormal/
          img_3.jpg
        normal/
          img_4.jpg
    
    opened by NouranFadlallah 3
Owner
Facebook Research
This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

null 1 Dec 24, 2021
Multi-Scale Vision Longformer: A New Vision Transformer for High-Resolution Image Encoding

Vision Longformer This project provides the source code for the vision longformer paper. Multi-Scale Vision Longformer: A New Vision Transformer for H

Microsoft 209 Dec 30, 2022
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Swin Transformer 1.4k Dec 30, 2022
Alex Pashevich 62 Dec 24, 2022
The implementation of "Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer"

Shuffle Transformer The implementation of "Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer" Introduction Very recently, window-

null 87 Nov 29, 2022
Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Swin-Transformer-Tensorflow A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Sh

null 52 Dec 29, 2022
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

CSWin-Transformer This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". Th

Microsoft 409 Jan 6, 2023
VSR-Transformer - This paper proposes a new Transformer for video super-resolution (called VSR-Transformer).

VSR-Transformer By Jiezhang Cao, Yawei Li, Kai Zhang, Luc Van Gool This paper proposes a new Transformer for video super-resolution (called VSR-Transf

Jiezhang Cao 225 Nov 13, 2022
This repository implements and evaluates convolutional networks on the Möbius strip as toy model instantiations of Coordinate Independent Convolutional Networks.

Orientation independent Möbius CNNs This repository implements and evaluates convolutional networks on the Möbius strip as toy model instantiations of

Maurice Weiler 59 Dec 9, 2022
CoSMA: Convolutional Semi-Regular Mesh Autoencoder. From Paper "Mesh Convolutional Autoencoder for Semi-Regular Meshes of Different Sizes"

Mesh Convolutional Autoencoder for Semi-Regular Meshes of Different Sizes Implementation of CoSMA: Convolutional Semi-Regular Mesh Autoencoder arXiv p

Fraunhofer SCAI 10 Oct 11, 2022
CMT: Convolutional Neural Networks Meet Vision Transformers

CMT: Convolutional Neural Networks Meet Vision Transformers [arxiv] 1. Introduction This repo is the CMT model which impelement with pytorch, no refer

FlyEgle 83 Dec 30, 2022
[Preprint] ConvMLP: Hierarchical Convolutional MLPs for Vision, 2021

Convolutional MLP ConvMLP: Hierarchical Convolutional MLPs for Vision Preprint link: ConvMLP: Hierarchical Convolutional MLPs for Vision By Jiachen Li

SHI Lab 143 Jan 3, 2023
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 1, 2023
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 1, 2023
PyTorch code of my ICDAR 2021 paper Vision Transformer for Fast and Efficient Scene Text Recognition (ViTSTR)

Vision Transformer for Fast and Efficient Scene Text Recognition (ICDAR 2021) ViTSTR is a simple single-stage model that uses a pre-trained Vision Tra

Rowel Atienza 198 Dec 27, 2022
This repo contains the official code and pre-trained models for the Dynamic Vision Transformer (DVT).

Dynamic-Vision-Transformer (Pytorch) This repo contains the official code and pre-trained models for the Dynamic Vision Transformer (DVT). Not All Ima

null 210 Dec 18, 2022
The code for our paper CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention.

CrossFormer This repository is the code for our paper CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention. Introduction Existin

cheerss 238 Jan 6, 2023
Official code for paper "Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight"

Demysitifing Local Vision Transformer, arxiv This is the official PyTorch implementation of our paper. We simply replace local self attention by (dyna

null 138 Dec 28, 2022
Code of TVT: Transferable Vision Transformer for Unsupervised Domain Adaptation

TVT Code of TVT: Transferable Vision Transformer for Unsupervised Domain Adaptation Datasets: Digit: MNIST, SVHN, USPS Object: Office, Office-Home, Vi

null 37 Dec 15, 2022