Refactoring dalle-pytorch and taming-transformers for TPU VM

Overview

Text-to-Image Translation (DALL-E) for TPU in PyTorch

Refactoring Taming Transformers and DALLE-pytorch for TPU VM with PyTorch Lightning

Requirements

pip install -r requirements.txt

Data Preparation

Place any image dataset in an ImageNet-style directory structure (at least one subfolder) so that it can be loaded with torchvision's ImageFolder.
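Before training, it can help to confirm that a directory matches this layout. The stdlib-only sketch below mirrors the structure ImageFolder expects (root/<class_name>/<image files>) and raises if there is no class subfolder. scan_image_folder is a hypothetical helper, not part of this repo, and the extension list is an assumption modeled on torchvision's defaults.

```python
from pathlib import Path
from typing import Dict, List

# Assumed extension list, modeled on torchvision's default image extensions.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def scan_image_folder(root: str) -> Dict[str, List[Path]]:
    """Map each class subfolder of `root` to the image files found under it.

    Expected layout: root/<class_name>/**/<image>. Raises FileNotFoundError
    when `root` has no subfolder, since ImageFolder needs at least one class.
    """
    root_path = Path(root)
    classes = sorted(d.name for d in root_path.iterdir() if d.is_dir())
    if not classes:
        raise FileNotFoundError(f"no class subfolders found in {root}")
    found: Dict[str, List[Path]] = {}
    for name in classes:
        # Search each class folder recursively; match extensions case-insensitively.
        found[name] = sorted(
            p for p in (root_path / name).rglob("*")
            if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
        )
    return found
```

For unlabeled data, a single placeholder class folder such as train/all/ satisfies the "at least one subfolder" requirement.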

Training VQVAEs

You can quickly test train_vae.py with randomly generated fake data:

python train_vae.py --use_tpus --fake_data

For actual training, specify the directories for --train_dir, --val_dir, and --log_dir:

python train_vae.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results]

Training DALL-E

python train_dalle.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results] --vae_path [pretrained vae] --bpe_path [pretrained bpe(optional)]

TODO

  • Refactor Encoder and Decoder modules for better readability
  • Refactor VQVAE2
  • Add Net2Net Conditional Transformer for conditional image generation
  • Refactor, optimize, and merge DALL-E with Net2Net Conditional Transformer
  • Add Guided Diffusion + CLIP for image refinement
  • Add VAE converter for JAX to support dalle-mini
  • Add DALL-E colab notebook
  • Add RBGumbelQuantizer
  • Add HiT

ON-GOING

  • Test large dataset loading on TPU Pods
  • Change the current DALL-E code to fully support the latest updates from DALLE-pytorch

DONE

  • Add VQVAE, VQGAN, Gumbel VQVAE (Discrete VAE), and Gumbel VQGAN
  • Add VQVAE2
  • Add EMA update for Vector Quantization
  • Debug VAEs (Single TPU Node, TPU Pods, GPUs)
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3028
  • Add DALL-E
  • Debug DALL-E (Single TPU Node, TPU Pods, GPUs)
  • Add WebDataset support
  • Add VAE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add DALLE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add automatic checkpoint saving and resuming for sudden TPU restarts (which happen often)
  • Reimplement EMA VectorQuantizer with nn.Embedding
  • Add DALL-E colab notebook by afiaka87
  • Add Normed Vector Quantizer by GallagherCommaJack
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3068
  • Debug WebDataset functionality

BibTeX

@misc{oord2018neural,
      title={Neural Discrete Representation Learning}, 
      author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
      year={2018},
      eprint={1711.00937},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{razavi2019generating,
      title={Generating Diverse High-Fidelity Images with VQ-VAE-2}, 
      author={Ali Razavi and Aaron van den Oord and Oriol Vinyals},
      year={2019},
      eprint={1906.00446},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{esser2020taming,
      title={Taming Transformers for High-Resolution Image Synthesis}, 
      author={Patrick Esser and Robin Rombach and Björn Ommer},
      year={2020},
      eprint={2012.09841},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{ramesh2021zeroshot,
      title={Zero-Shot Text-to-Image Generation},
      author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
      year={2021},
      eprint={2102.12092},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Comments
  • DALL-E training code (often) gets stuck indefinitely after completing validation

    Hi @tgisaturday and @afiaka87, I've been working with your repo for a while and overall I haven't had many problems. However, since last week I have noticed that the DALL-E training code often doesn't exit the process after completing training and validation. Instead, it gets stuck indefinitely in that state; two weeks ago I had never encountered this behavior.

    Screenshot of a run that works fine: [image]

    Screenshot of a run that gets stuck: [image]

    Usually I try to solve issues by myself, but this time, without an error message, I have no idea where to start. The only thing I can do is restart the Colab session and rerun the cells, because the checkpoint is not saved. This is quite annoying because it forces me to train the model one epoch at a time to minimize wasted time in case the problem shows up, so it would be great if you could give me any idea on how to solve it.

    Thank you both for your help.

    opened by matteopilotto 13
  • Hi, after the update it seems to complain about ch_mult's default value: "train_vae.py: error: argument --ch_mult: invalid int value: '1,1,2,2,4,'"

    WARNING:root:TPU has started up successfully with version pytorch-1.9
    GPU available: False, used: False
    TPU available: True, using: 8 TPU cores
    IPU available: False, using: 0 IPUs
    usage: train_vae.py [-h] [--train_dir TRAIN_DIR] [--val_dir VAL_DIR] [--log_dir LOG_DIR] [--backup_dir BACKUP_DIR] [--ckpt_path CKPT_PATH] [--resume] [--backup] [--backup_steps BACKUP_STEPS] [--log_images] [--image_log_steps IMAGE_LOG_STEPS] [--refresh_rate REFRESH_RATE] [--precision PRECISION] [--fake_data] [--use_tpus] [--seed SEED] [--gpus GPUS] [--gpu_dist] [--tpus TPUS] [--num_sanity_val_steps NUM_SANITY_VAL_STEPS] [--learning_rate LEARNING_RATE] [--lr_decay] [--starting_temp STARTING_TEMP] [--temp_min TEMP_MIN] [--anneal_rate ANNEAL_RATE] [--batch_size BATCH_SIZE] [--epochs EPOCHS] [--num_workers NUM_WORKERS] [--img_size IMG_SIZE] [--resize_ratio RESIZE_RATIO] [--test] [--debug] [--xla_stat] [--web_dataset] [--dataset_size DATASET_SIZE] [--model MODEL] [--codebook_dim CODEBOOK_DIM] [--num_tokens NUM_TOKENS] [--double_z DOUBLE_Z] [--z_channels Z_CHANNELS] [--resolution RESOLUTION] [--in_channels IN_CHANNELS] [--out_channels OUT_CHANNELS] [--hidden_dim HIDDEN_DIM] [--ch_mult CH_MULT [CH_MULT ...]] [--num_res_blocks NUM_RES_BLOCKS] [--attn_resolutions ATTN_RESOLUTIONS [ATTN_RESOLUTIONS ...]] [--dropout DROPOUT] [--quant_beta QUANT_BETA] [--quant_ema_decay QUANT_EMA_DECAY] [--quant_ema_eps QUANT_EMA_EPS] [--num_res_ch NUM_RES_CH] [--smooth_l1_loss] [--kl_loss_weight KL_LOSS_WEIGHT] [--disc_conditional DISC_CONDITIONAL] [--disc_in_channels DISC_IN_CHANNELS] [--disc_start DISC_START] [--disc_weight DISC_WEIGHT] [--codebook_weight CODEBOOK_WEIGHT] [--wandb]
    train_vae.py: error: argument --ch_mult: invalid int value: '1,1,2,2,4,'

    bug documentation 
    opened by jordanvalter 10
  • Error while Training Dall-E on a single TPU (8cores)

    Hi, I am trying to train DALL-E on the COCO dataset, and here are the parameters I use:

    %%writefile /content/tmp/run.sh
    #@title Configuration
    # model
    model="vqgan" #@param  ['vqgan','evqgan','gvqgan','vqvae','evqvae','gvqvae','vqvae2']
    # training
    epochs=30 #@param {'type': 'raw'}
    learning_rate=4.5e-6 #@param {'type': 'number' }
    precision=16 #@param {'type': 'integer' }
    batch_size=8 #@param {'type': 'raw'}
    num_workers=8 #@param {'type': 'raw'} 
    # fake_data=True #@param {'type': 'boolean' }
    use_tpus=True #@param {'type': 'boolean' }
    
    
    # modifiable
    resume=False #@param {type: 'boolean'}
    dropout=0.1 #@param {type: 'number'}
    rescale_img_size=256 #@param {type: 'number'}
    resize_ratio=0.75 #@param {type: 'number'}
    # test=True #@param {type: 'boolean'}
    seed=8675309
    codebook_dim=1024
    embedding_dim=256
    
    python '/content/dalle-lightning-modified-/train_dalle.py' \
        --epochs $epochs \
        --learning_rate $learning_rate \
        --precision $precision \
        --batch_size $batch_size \
        --num_workers $num_workers \
        --use_tpus \
        --train_dir "/content/data/train/" \
        --val_dir "/content/data/test" \
        --vae_path "/content/vae_logs/last.ckpt"  \
        --log_dir "/content/dalle_logs/" \
        --img_size $rescale_img_size \
        --seed $seed \
        --resize_ratio $resize_ratio \
        --embedding_dim $embedding_dim \
        --codebook_dim $codebook_dim
    

    When running I get the following error:

    WARNING:root:TPU has started up successfully with version pytorch-1.9
    Global seed set to 8675309
    GPU available: False, used: False
    TPU available: True, using: 8 TPU cores
    IPU available: False, using: 0 IPUs
    Setting batch size: 8 learning rate: 4.50e-06
    
    Global seed set to 8675309
    Global seed set to 8675309
    Global seed set to 8675309
    Global seed set to 8675309
    Global seed set to 8675309
    Global seed set to 8675309
    Global seed set to 8675309
    Global seed set to 8675309
    
      | Name          | Type                     | Params
    -----------------------------------------------------------
    0 | text_emb      | Embedding                | 5.3 M 
    1 | image_emb     | Embedding                | 4.2 M 
    2 | text_pos_emb  | Embedding                | 131 K 
    3 | image_pos_emb | AxialPositionalEmbedding | 32.8 K
    4 | vae           | OpenAIDiscreteVAE        | 97.6 M
    5 | transformer   | Transformer              | 268 M 
    6 | to_logits     | Sequential               | 9.5 M 
    -----------------------------------------------------------
    288 M     Trainable params
    97.6 M    Non-trainable params
    385 M     Total params
    771.301   Total estimated model params size (MB)
    Exception in device=TPU:5: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    Exception in device=TPU:3: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    Exception in device=TPU:7: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    Exception in device=TPU:0: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    Exception in device=TPU:1: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    Exception in device=TPU:2: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    Exception in device=TPU:6: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    Exception in device=TPU:4: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
    

    One process sees the folder of 94629 images and texts and the rest see 0 images and texts. I do not understand why this is happening. Could you please help me with this? Any idea?

    opened by mkhoshle 9
  • TypeError: Expected 'Iterator' as the return annotation for `__iter__` of ExperienceSourceDataset, but found typing.Iterable

    Hi, I'm using TPUs and I have also tried GPUs, but this error still persists. And yes, I'm using the newest version.

    WARNING:root:Waiting for TPU to be start up with version pytorch-1.9...
    WARNING:root:Waiting for TPU to be start up with version pytorch-1.9...
    WARNING:root:TPU has started up successfully with version pytorch-1.9
    Traceback (most recent call last):
      File "/content/lib/dalle-lightning/train_vae.py", line 17, in <module>
        from pl_dalle.callbacks import ReconstructedImageLogger
      File "/content/lib/dalle-lightning/pl_dalle/callbacks.py", line 8, in <module>
        from pl_bolts.utils import _TORCHVISION_AVAILABLE
      File "/usr/local/lib/python3.7/dist-packages/pl_bolts/__init__.py", line 19, in <module>
        from pl_bolts import (  # noqa: E402
      File "/usr/local/lib/python3.7/dist-packages/pl_bolts/datamodules/__init__.py", line 5, in <module>
        from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset
      File "/usr/local/lib/python3.7/dist-packages/pl_bolts/datamodules/experience_source.py", line 24, in <module>
        class ExperienceSourceDataset(IterableDataset):
      File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_typing.py", line 273, in __new__
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
      File "/usr/lib/python3.7/abc.py", line 126, in __new__
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
      File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_typing.py", line 371, in _dp_init_subclass
        ", but found {}".format(sub_cls.__name__, _type_repr(hints['return'])))
    TypeError: Expected 'Iterator' as the return annotation for `__iter__` of ExperienceSourceDataset, but found typing.Iterable

    opened by jordanvalter 6
  • num_workers causes train_vae.py to hang indefinitely

    Hello! I've recently attempted to train a VAE on a TPU v3-8 VM and it hangs during validation. I've fixed it by setting args.num_workers to 0, but of course, it makes training much slower.

    ██████████████████| 7/7 [02:08<00:00, 29.00s/it]
    Traceback (most recent call last):
      File "train_vae.py", line 259, in <module>
        trainer.fit(model, datamodule=datamodule)
      File "/home/haru/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
        self._run(model)
      File "/home/haru/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
        self._dispatch()
      File "/home/haru/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1003, in _dispatch
        self.accelerator.start_training(self)
      File "/home/haru/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
        self.training_type_plugin.start_training(trainer)
      File "/home/haru/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/tpu_spawn.py", line 267, in start_training
        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
      File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 388, in spawn
        return torch.multiprocessing.start_processes(
      File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
        while not context.join():
      File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 99, in join
        ready = multiprocessing.connection.wait(
      File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
        ready = selector.select(timeout)
      File "/usr/lib/python3.8/selectors.py", line 415, in select
        fd_event_list = self._selector.poll(timeout)
    
    documentation 
    opened by harubaru 5
  • Bad Iterator Type

    Just got this one trying to train the VQGAN with GPUs again. Edit: it seems to occur with the VQVAE2 as well.

      RequestsDependencyWarning)
    Traceback (most recent call last):
      File "/content/lib/dalle-lightning/train_vae.py", line 22, in <module>
        from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler
      File "/usr/local/lib/python3.7/dist-packages/pl_bolts/__init__.py", line 19, in <module>
        from pl_bolts import (  # noqa: E402
      File "/usr/local/lib/python3.7/dist-packages/pl_bolts/datamodules/__init__.py", line 5, in <module>
        from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset
      File "/usr/local/lib/python3.7/dist-packages/pl_bolts/datamodules/experience_source.py", line 24, in <module>
        class ExperienceSourceDataset(IterableDataset):
      File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_typing.py", line 273, in __new__
        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
      File "/usr/lib/python3.7/abc.py", line 126, in __new__
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
      File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_typing.py", line 371, in _dp_init_subclass
        ", but found {}".format(sub_cls.__name__, _type_repr(hints['return'])))
    TypeError: Expected 'Iterator' as the return annotation for `__iter__` of ExperienceSourceDataset, but found typing.Iterable
    
    bug 
    opened by afiaka87 5
  • Colab Notebook (WIP)

    Hey, I made a Colab notebook to try out this repo on GPU. It also seems to work on TPU in my testing, which is good to know. There are lots of commented-out blocks currently, as I mostly wanted to get the base notebook written out quickly. Feel free to use any or none of it.

    https://github.com/afiaka87/dalle-lightning/blob/notebook/dalle_lightning.ipynb

    @robvanvolt may be interested.

    opened by afiaka87 5
  • VQGAN training breaks on GPU

      File "/content/lib/dalle-lightning/pl_dalle/modules/vqvae/quantize.py", line 20, in forward
        z_flattened = z.view(-1, self.embedding_dim)
    RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    

    Works fine on TPU, breaks on GPU. Will reshape work on TPU?

    opened by afiaka87 3
  • VQGAN callback breaks

    Getting an error when training VQGAN:

    https://github.com/tgisaturday/dalle-lightning/blob/b4164106fee0d25ad0312ff1cbc24516c78b905d/pl_dalle/callbacks.py#L70

    Modifying the callback to take the first element of outputs beforehand fixes the issue:

        ) -> None:
            """Called when the train batch ends."""
            if trainer.global_step % self.every_n_steps == 0:
                # x = outputs['x']
                x = outputs[0]['x']
                xrec = outputs[0]['xrec']
    
    opened by afiaka87 2
  • vqgan working now (except there's a small bug)

    Great work with writing out the pl_bolts requirement. So far I can confirm that VAE training now works on GPU in Colab, at least for VQGAN.

    For some reason the args.gpus flag doesn't seem to make it to the Trainer. I couldn't figure out what was wrong, and I'm a bit scattered at the moment.

    Anyway, a small benefit of that is that I accidentally tested it on CPU, which works as well.

    opened by afiaka87 2
  • webdataset bugfixes, optional wandb logging for VAE training, VQGAN discriminator warmup + spectral norm

    I had a bunch of pieces I wanted to change for my personal VQGAN training run. I'm not sure you want all of them, but I'm pretty sure they're all improvements.

    For VQGAN, using a hinge loss means there's little reason not to train the discriminator in the early steps, before its gradients are backpropagated through the generator, especially since the current training loop runs those forward passes anyway. I also added spectral normalization to the discriminator to keep the discriminator-induced gradients from exploding.

    The other changes are more clearly positive: some bugs in WebDataset loading made the inputs batches of batches instead of plain batches. These changes might need additional patching to support conditional generators/discriminators, but since the original code didn't work at all, I still see this as a strict improvement.

    I also added optional wandb logging. wandb is imported only if the --wandb argument is passed to train_vae.py; when enabled, it logs the VQGAN's gradients and logs reconstructed images in a table instead of in separate grids.

    opened by GallagherCommaJack 2
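The "--ch_mult: invalid int value: '1,1,2,2,4,'" report above matches what argparse does when a flag declared with type=int has a string default: if the flag is omitted, argparse passes the string default through the type converter, and int('1,1,2,2,4,') fails. The sketch below reproduces this, assuming (from the usage string "--ch_mult CH_MULT [CH_MULT ...]") that the flag uses nargs='+'. The default string is taken from the error message, and build_parser is a hypothetical helper, not the repo's parser.

```python
import argparse


def build_parser(ch_mult_default) -> argparse.ArgumentParser:
    """Hypothetical parser holding only --ch_mult (not the repo's full parser)."""
    parser = argparse.ArgumentParser(prog="train_vae.py")
    # nargs="+" matches the usage string "--ch_mult CH_MULT [CH_MULT ...]",
    # so values are passed space-separated: --ch_mult 1 1 2 2 4
    parser.add_argument("--ch_mult", type=int, nargs="+", default=ch_mult_default)
    return parser


# Assumed buggy default: when --ch_mult is omitted, argparse runs this *string*
# default through type=int, int("1,1,2,2,4,") raises ValueError, and argparse
# exits with "argument --ch_mult: invalid int value: '1,1,2,2,4,'".
buggy = build_parser("1,1,2,2,4,")

# Fix: give the default as a list of ints; non-string defaults are not converted.
fixed = build_parser([1, 1, 2, 2, 4])

print(fixed.parse_args([]).ch_mult)                            # [1, 1, 2, 2, 4]
print(fixed.parse_args(["--ch_mult", "1", "2", "4"]).ch_mult)  # [1, 2, 4]
```

Passing the flag explicitly (for example --ch_mult 1 1 2 2 4) also avoids the error with the buggy parser, because argparse only converts the string default when the flag is omitted.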
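On the "VQGAN training breaks on GPU" report: .view() only works when the tensor's strides are compatible with the new shape, and a tensor that was just permuted usually is not contiguous, which is what the RuntimeError says. A minimal sketch, assuming z arrives as (B, C, H, W) and is permuted channels-last before flattening (as in taming-transformers' quantizer); flatten_latents is a hypothetical helper. This sketch does not test the behavior on TPU.

```python
import torch


def flatten_latents(z: torch.Tensor, embedding_dim: int) -> torch.Tensor:
    """Flatten a (B, C, H, W) latent map into (B*H*W, C) rows for codebook lookup."""
    # permute() only reorders strides, so the result is generally non-contiguous.
    z = z.permute(0, 2, 3, 1)
    # .view() can raise the RuntimeError above on this non-contiguous tensor;
    # .reshape() returns a view when strides allow it and copies otherwise.
    # z.contiguous().view(-1, embedding_dim) is an equivalent alternative.
    return z.reshape(-1, embedding_dim)


z = torch.randn(2, 4, 3, 3)  # B=2, C=embedding_dim=4, H=W=3
flat = flatten_latents(z, embedding_dim=4)
print(flat.shape)  # torch.Size([18, 4])
```

Both fixes produce the same values; .reshape() simply avoids the extra copy when the input already happens to be contiguous.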
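The discriminator warmup described in the "webdataset bugfixes ... VQGAN discriminator warmup + spectral norm" PR above can be sketched as follows: the discriminator trains with a hinge loss from step 0, while the adversarial term in the generator's loss is weighted by zero until disc_start (a flag that appears in train_vae.py's usage string). The helper names and the two-layer discriminator are hypothetical, not the PR's code; only torch.nn.utils.spectral_norm and the hinge-loss formula are standard.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Discriminator hinge loss: penalize real logits below +1 and fake logits above -1."""
    loss_real = F.relu(1.0 - logits_real).mean()
    loss_fake = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (loss_real + loss_fake)


def generator_adv_weight(global_step: int, disc_start: int, disc_weight: float = 1.0) -> float:
    """Weight on the adversarial term of the generator loss (hypothetical helper).

    Zero before disc_start, so the discriminator can already learn from the
    hinge loss while none of its gradients reach the generator yet.
    """
    return disc_weight if global_step >= disc_start else 0.0


# spectral_norm rescales each wrapped layer's weight by a power-iteration estimate
# of its largest singular value, which limits how large D's gradients can get.
disc = nn.Sequential(
    nn.utils.spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)),
    nn.LeakyReLU(0.2),
    nn.utils.spectral_norm(nn.Conv2d(64, 1, kernel_size=4, stride=1, padding=1)),
)
```

In a training step, the discriminator loss would be computed every step, while the generator's adversarial term is multiplied by generator_adv_weight(global_step, disc_start, disc_weight).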
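The optional wandb import described in the same PR can be done by deciding at argument-parsing time whether to import the package at all. A minimal sketch using importlib; parse_and_maybe_import_wandb is a hypothetical name, and only the --wandb flag itself comes from the usage string above.

```python
import argparse
import importlib
from types import ModuleType
from typing import List, Optional, Tuple


def parse_and_maybe_import_wandb(
    argv: Optional[List[str]] = None,
) -> Tuple[argparse.Namespace, Optional[ModuleType]]:
    """Parse --wandb and import wandb only when the flag is set."""
    parser = argparse.ArgumentParser(prog="train_vae.py")
    parser.add_argument("--wandb", action="store_true", help="enable wandb logging")
    args = parser.parse_args(argv)
    # Runs without --wandb never touch the package, so it need not be installed.
    wandb = importlib.import_module("wandb") if args.wandb else None
    return args, wandb
```

Downstream code can then check `wandb is not None` before logging gradients or image tables.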
Owner
Kim, Taehoon
Research Scientist & Machine Learning Engineer.