Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

Overview

DALL-E in Pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the generations.

Sid, Ben, and Aran over at EleutherAI are working on DALL-E for Mesh TensorFlow! Please lend them a hand if you would like to see DALL-E trained on TPUs.

Yannic Kilcher's video

Before we replicate this, we can settle for Deep Daze or Big Sleep

Status

Hannu has managed to train a small 6-layer DALL-E on a dataset of just 2000 landscape images (2048 visual tokens)!

Install

$ pip install dalle-pytorch

Usage

Train VAE

import torch
from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,          # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,       # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,      # codebook dimension
    hidden_dim = 64,         # hidden dimension
    num_resnet_blocks = 1,   # number of resnet blocks
    temperature = 0.9,       # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False # straight-through for gumbel softmax. unclear if it is better one way or the other
)

images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_recon_loss = True)
loss.backward()

# train with a lot of data to learn a good codebook

Train DALL-E with pretrained VAE from above

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(              # same settings as the VAE trained above; load its trained weights before training DALL-E
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 12,                 # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = dalle(text, images, mask = mask, return_loss = True)
loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(text, mask = mask)
images.shape # (4, 3, 256, 256)
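The comment on vae = vae says DALLE infers the image sequence length and the number of image tokens from the VAE. Below is a minimal sketch of that arithmetic. It assumes each of the num_layers downsamples halves the resolution, as the comment in the VAE example states. infer_image_tokens is a hypothetical helper, not part of dalle_pytorch.

```python
def infer_image_tokens(image_size: int, num_layers: int, num_tokens: int, text_seq_len: int) -> dict:
    """Hypothetical helper: derive sequence sizes from VAE hyperparameters."""
    # Each downsample halves the spatial resolution: 256 / 2**3 = 32.
    fmap_size = image_size // (2 ** num_layers)
    # The VAE emits one discrete code per feature-map cell.
    image_seq_len = fmap_size ** 2
    # Text tokens and image tokens are modelled as one sequence.
    total_seq_len = text_seq_len + image_seq_len
    return {
        "fmap_size": fmap_size,
        "image_seq_len": image_seq_len,
        "num_image_tokens": num_tokens,  # codebook size = image vocabulary size
        "total_seq_len": total_seq_len,
    }

sizes = infer_image_tokens(image_size=256, num_layers=3, num_tokens=8192, text_seq_len=256)
print(sizes)  # {'fmap_size': 32, 'image_seq_len': 1024, 'num_image_tokens': 8192, 'total_seq_len': 1280}
```

For the configuration above, this gives 1024 image tokens per 256x256 image, so each training example is a 1280-token sequence.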

Ranking the generations

Train CLIP

import torch
from dalle_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()

To get the similarity scores from your trained CLIP model, pass it to generate_images:

images, scores = dalle.generate_images(text, mask = mask, clip = clip)

scores.shape # (4,)
images.shape # (4, 3, 256, 256)

# do your top-k here; in the paper they sampled 512 and chose the top 32

Or you can just use the official CLIP model to rank the images from DALL-E
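For the top-k step, here is a minimal sketch in plain Python. It assumes you have one CLIP similarity score per generated image, for example scores.tolist() on the tensor returned above. With tensors, torch.topk(scores, k) returns the top values and their indices directly. top_k_indices is a hypothetical helper.

```python
def top_k_indices(scores, k):
    """Return indices of the k highest scores, best first (hypothetical helper)."""
    k = min(k, len(scores))
    # Sort positions by score, highest first; sorted() is stable, so ties keep their order.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]

# Toy scores for 6 generations; the paper's setting would be 512 samples, keep 32.
scores = [0.21, 0.35, 0.18, 0.40, 0.33, 0.27]
best = top_k_indices(scores, k=3)
print(best)  # [3, 1, 4]
```

When images is a tensor, images[best] then gathers the selected generations.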

Scaling depth

In the blog post, they used 64 layers to achieve their results. I added reversible networks, from the Reformer paper, so that users can attempt to scale depth at the cost of compute. Reversible networks let you scale depth without storing per-layer activations, so activation memory does not grow with depth. The trade-off is a little over 2x the compute cost, because each layer is rerun on the backward pass.

Simply set the reversible keyword to True for the DALLE class:

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True  # <-- reversible networks https://arxiv.org/abs/2001.04451
)
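Here is why reversible layers save memory. Each block splits its input into two halves and computes y1 = x1 + F(x2) and y2 = x2 + G(y1). The inputs can therefore be recomputed from the outputs during the backward pass instead of being stored. Below is a minimal sketch of that coupling on plain Python lists. f and g are arbitrary stand-ins for the attention and feedforward sublayers. This illustrates the Reformer idea; it is not the library's implementation.

```python
from typing import Callable, List, Tuple

Vec = List[float]

def add(a: Vec, b: Vec) -> Vec:
    return [x + y for x, y in zip(a, b)]

def sub(a: Vec, b: Vec) -> Vec:
    return [x - y for x, y in zip(a, b)]

def reversible_forward(x1: Vec, x2: Vec, f: Callable[[Vec], Vec], g: Callable[[Vec], Vec]) -> Tuple[Vec, Vec]:
    # Reformer-style coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)
    y1 = add(x1, f(x2))
    y2 = add(x2, g(y1))
    return y1, y2

def reversible_inverse(y1: Vec, y2: Vec, f: Callable[[Vec], Vec], g: Callable[[Vec], Vec]) -> Tuple[Vec, Vec]:
    # Recover the inputs from the outputs by rerunning g and f (the extra compute)
    x2 = sub(y2, g(y1))
    x1 = sub(y1, f(x2))
    return x1, x2

f = lambda v: [2.0 * t + 1.0 for t in v]  # stand-in for attention
g = lambda v: [t * t for t in v]          # stand-in for feedforward

y1, y2 = reversible_forward([0.5, -1.0], [2.0, 3.0], f, g)
print(reversible_inverse(y1, y2, f, g))  # ([0.5, -1.0], [2.0, 3.0])
```

Because the inverse only needs the outputs, a reversible stack can discard intermediate activations. Rerunning f and g during the inverse is where the extra compute comes from.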

Sparse Attention

You can also train with Microsoft DeepSpeed's Sparse Attention, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process.

First, you need to install DeepSpeed with Sparse Attention:

$ sh install_deepspeed.sh

Next, you need to install the pip package triton

$ pip install triton

If both of the above succeeded, now you can train with Sparse Attention!

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 8,
    sparse_attn = (True, False) * 32  # interleave sparse and dense attention for 64 layers
)

Citations

@misc{unpublished2021dalle,
    title   = {DALL·E: Creating Images from Text},
    author  = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray},
    year    = {2021}
}
@misc{unpublished2021clip,
    title  = {CLIP: Connecting Text and Images},
    author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
    year   = {2021}
}
@misc{kitaev2020reformer,
    title   = {Reformer: The Efficient Transformer},
    author  = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
    year    = {2020},
    eprint  = {2001.04451},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

Those who do not want to imitate anything, produce nothing. - Dali

Comments
  • VQGanVAE1024

    VQGanVAE1024 "vae must be an instance of DiscreteVAE"

    @lucidrains I believe the relevant line is here:

    https://github.com/lucidrains/DALLE-pytorch/blob/2268864941d8eef2ba73a4488fe05673d447d493/dalle_pytorch/dalle_pytorch.py#L306

    I tried adding it in myself, but it needs the taming imports and I'm not familiar with those.

    opened by afiaka87 64
  • More

    More "OpenAI Blog Post" Training | Depth 32 | Heads 8 | LR 5e-4

    Edit: Moved to discussions: https://github.com/lucidrains/DALLE-pytorch/discussions/106

    Hey, all. Some of you might know I'm practicing and learning about machine learning with dalle-pytorch and a dataset consisting of the images OpenAI presented in the DALL-E blog post. I honestly don't have the money to train this whole dataset,

    edit: this is no longer true. Using the 1024 VQGAN from the "Taming Transformers" research, it's now quite possible to train on a full dataset of 1,000,000 image-text pairs, and I'm doing just that. I hope to have it finished in about a week. I assume someone else will release a dalle-pytorch model trained properly on COCO and other image sets before then, but if they don't, check here for updates.

    Anyway, it ran for ~36000 steps. As you can see, it... still really likes mannequins. I'm considering removing them from the dataset. But you'll also notice that the network has actually got a decent idea of the general colors that belong with each type of prompt.

    Some Samples from Near the End of Training

    results

    Every Text-Image Reconstruction

    https://wandb.ai/afiaka87/dalle_pytorch_live_training/reports/dalle-pytorch-Test-Run-2--Vmlldzo1MzM5MjQ

    Deliverables (my train_dalle.py)

    https://gist.github.com/afiaka87/850fb3cc48edde8a7ed4cb1ce53b6bd2

    This has some code in it that actually manages to deal with truncated images via a try/except. Apparently detecting a corrupted PNG is harder than P vs NP. PIL's Image.verify() doesn't catch all of them, and Python's built-in imghdr library doesn't either. So you just catch OSError and return an item further along. Works well enough.
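    The "catch OSError and return an item further along" approach can be sketched as a map-style dataset. This is a minimal sketch, not the code from the gist. load_fn is a hypothetical loader, for example one that opens the file with PIL. The max_skips cap is an added guard so that a run of corrupt files cannot loop forever. PIL raises OSError, or its subclass UnidentifiedImageError, for truncated or unreadable files.

```python
from typing import Callable, Sequence

class SkipCorruptDataset:
    """Map-style dataset that skips unreadable files (hypothetical sketch)."""

    def __init__(self, paths: Sequence[str], load_fn: Callable[[str], object], max_skips: int = 10):
        self.paths = list(paths)
        self.load_fn = load_fn
        self.max_skips = max_skips

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int):
        for offset in range(self.max_skips + 1):
            path = self.paths[(index + offset) % len(self.paths)]
            try:
                return self.load_fn(path)
            except OSError:
                continue  # truncated/corrupt file: try the next item along
        raise RuntimeError(f"{self.max_skips + 1} unreadable files in a row from index {index}")

# Toy loader that treats any path starting with "bad" as corrupt.
def toy_load(path: str) -> str:
    if path.startswith("bad"):
        raise OSError(f"image file is truncated: {path}")
    return path

ds = SkipCorruptDataset(["a.png", "bad1.png", "bad2.png", "d.png"], toy_load)
print(ds[1])  # d.png
```

    A skipped index returns a neighboring sample, so a few images are seen twice per epoch.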

    Parameters

    SHUFFLE = True
    EPOCHS = 28 # This wound up being less than a single epoch, of course. 
    BATCH_SIZE = 16
    LEARNING_RATE = 0.0005 # I found this learning rate to be more suitable than 0.0003 in my hyperparameter sweep post
    GRAD_CLIP_NORM = 0.5
    DEPTH = 32
    HEADS = 8
    MODEL_DIM = 512
    TEXT_SEQ_LEN = 256
    DIM_HEAD = 64
    REVERSIBLE = True
    ATTN_TYPES = ('full',) # the trailing comma makes this a tuple; ('full') is just the string 'full'
    

    Dataset Description

    https://github.com/lucidrains/DALLE-pytorch/issues/61#issuecomment-796663342

    Just for more info on the dataset itself: it is roughly 1,100,000 256x256 image-text pairs that were generated by OpenAI's DALL-E. They presented roughly 30k unique text prompts, for each of which they posted the top 32 of 512 generations on https://openai.com/blog/dall-e/. Many images were corrupt, and not every prompt has a full 32 examples, but the total number of images winds up being about 1.1 million. If you look at many of the examples on that page, you'll see that DALL-E (in that form at least) can and will make mistakes. These mistakes are also in this dataset. Anyway, I'm just messing around, having fun training and whatnot. This is definitely not going to produce a good model or anything.

    There are also a large number of images in the dataset which are intended to be used with the "mask" feature. I don't know if that's possible yet in DALLE-pytorch though. Anyway, that can't be helping much.

    opened by afiaka87 31
  • Huge GPU memory consumption when using DeepSpeed

    So I decided to try the recently introduced DeepSpeed training on 4x V100 GPUs. Previously I was training on a single T4 or V100 (both have 16GB RAM) with certain model parameters (1024 model dim, 16 layers, 8 heads, 64 head dim if that's important). I was able to use a batch size of 2 with this configuration (256px images and no texts).

    I tried to launch distributed training with DeepSpeed with the same model parameters, but to my surprise, it gave an OOM error. Using binary search I found that it's possible to train the model using DeepSpeed but only when I set the number of layers to 4 (4x reduction) and use only a single sample per batch per GPU (so, 2x reduction).

    Am I missing something? I'm using the latest master branch with some minor changes that are very unlikely to cause such behavior.

    Any help/suggestion is very much appreciated!

    opened by ex4sperans 29
  • Add tensorboard support.

    Weights & Biases is cool, but it doesn't support the offline use case as well as TensorBoard currently does. This fix simply adds TensorBoard as a fallback option; W&B will still run if it's available.
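    Below is a minimal sketch of "W&B if available, TensorBoard otherwise". It assumes availability is judged by whether the package can be imported. pick_backend and make_logger are hypothetical helpers, not the PR's code. wandb.init, wandb.log and torch.utils.tensorboard.SummaryWriter.add_scalar are the standard calls of those libraries. A print fallback covers machines with neither package installed.

```python
import importlib.util
from typing import Callable, Dict

Logger = Callable[[Dict[str, float], int], None]

def module_available(name: str) -> bool:
    # find_spec returns None when a top-level package is not installed
    return importlib.util.find_spec(name) is not None

def pick_backend(is_available: Callable[[str], bool] = module_available) -> str:
    """Prefer W&B, fall back to TensorBoard, else print (hypothetical helper)."""
    for name in ("wandb", "tensorboard"):
        if is_available(name):
            return name
    return "stdout"

def make_logger(backend: str, project: str = "dalle", log_dir: str = "runs") -> Logger:
    """Return a log(metrics, step) function for the chosen backend."""
    if backend == "wandb":
        import wandb  # imported lazily so the package stays optional
        wandb.init(project=project)
        return lambda metrics, step: wandb.log(metrics, step=step)
    if backend == "tensorboard":
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(log_dir=log_dir)

        def log(metrics: Dict[str, float], step: int) -> None:
            for tag, value in metrics.items():
                writer.add_scalar(tag, value, step)
        return log
    return lambda metrics, step: print(f"step {step}: {metrics}")

# In a training script this would be make_logger(pick_backend()).
log = make_logger("stdout")
log({"loss": 1.23}, 0)  # step 0: {'loss': 1.23}
```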

    opened by afiaka87 22
  • Trained for 17K iters on COCO2014, OpenImages and OpenAI's Blog Images

    In case you haven't read my usual disclaimer: this dataset is weird. The repetition in the OpenAI images causes those to be highly overfit (mannequins), and the remainder of the dataset is much more diverse, which dalle-pytorch doesn't manage to capture very well here. Also, keep in mind that this isn't even a full epoch. Just having fun. Try not to evaluate this as representative of dalle-pytorch's current capabilities.

    Hey everyone. @lucidrains recently added the new, lighter pretrained VAE from the taming-transformers group. It uses substantially less memory and compute. I decided to take all the datasets I've collected thus far, put them in a single folder on an A100, and train dalle-pytorch for several hours.

    Here are the results:

    https://wandb.ai/afiaka87/OpenImagesV6/reports/Training-on-COCO-OpenImage-Blogpost--Vmlldzo1NDE3NjU

    I'm exhausted so that's all for now, but please click the link and have a look at the thousands of reconstructions it made (and the horrible captions from the "Localized Narratives" dataset I got from Google). I'll be updating this post with more info throughout the day.

    opened by afiaka87 19
  • New error using the new update.

    [2021-09-13 11:39:11,114] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed info: version=0.5.1, git-hash=unknown, git-branch=unknown
    [2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed groups
    [2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed model parallel group with size 1
    [2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed expert parallel group with size 1
    [2021-09-13 11:39:11,217] [INFO] [logging.py:68:log_dist] [Rank 0] creating expert data parallel process group with ranks: [0]
    [2021-09-13 11:39:11,217] [INFO] [logging.py:68:log_dist] [Rank 0] creating expert parallel process group with ranks: [0]
    [2021-09-13 11:39:11,240] [INFO] [engine.py:198:__init__] DeepSpeed Flops Profiler Enabled: False
    Traceback (most recent call last):
      File "train_dalle.py", line 497, in <module>
        config_params=deepspeed_config,
      File "/home/valterjordan/DALLE-pytorch/dalle_pytorch/distributed_backends/distributed_backend.py", line 152, in distribute
        **kwargs,
      File "/home/valterjordan/DALLE-pytorch/dalle_pytorch/distributed_backends/deepspeed_backend.py", line 162, in _distribute
        **kwargs,
      File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/__init__.py", line 141, in initialize
        config_params=config_params)
      File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 204, in __init__
        self.training_dataloader = self.deepspeed_io(training_data)
      File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1188, in deepspeed_io
        data_parallel_rank=data_parallel_rank)
      File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/dataloader.py", line 52, in __init__
        rank=data_parallel_rank)
      File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/torch/utils/data/distributed.py", line 87, in __init__
        self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore
    TypeError: object of type 'Processor' has no len()

    opened by jaoeded 18
  • Ideas for datasets to use? (WIT just came out)

    Hey all,

    I'm compiling a list of the various datasets we'll need and how to download them:

    Keep in mind, not all of these datasets ship with captions. However many of them do ship with a class descriptor of some type. I've only done mild testing with this, but usually you can just generate labels by doing something like "an image of {class_name}". Not sure what the best way to go about that would be though.

    https://github.com/lucidrains/DALLE-pytorch/discussions/109
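    For datasets that ship only class labels, the "an image of {class_name}" idea can be sketched as follows. captions_from_labels is hypothetical. Replacing underscores in label names is an assumption about how the labels are stored.

```python
from typing import Iterable, List, Tuple

def captions_from_labels(
    samples: Iterable[Tuple[str, str]],
    template: str = "an image of {class_name}",
) -> List[Tuple[str, str]]:
    """Turn (image_path, class_name) pairs into (image_path, caption) pairs."""
    pairs = []
    for path, class_name in samples:
        # Assumed label format: words joined by underscores, e.g. "golden_retriever".
        readable = class_name.replace("_", " ")
        pairs.append((path, template.format(class_name=readable)))
    return pairs

print(captions_from_labels([("img/001.jpg", "golden_retriever"), ("img/002.jpg", "tabby_cat")]))
# [('img/001.jpg', 'an image of golden retriever'), ('img/002.jpg', 'an image of tabby cat')]
```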

    As it stands, this is turning out to be humongous. I just added the new Wikipedia dataset (11 million images).

    Does anyone know of other captioned datasets we could use?

    opened by afiaka87 17
  • refactor

    We need a refactor, maybe even a redesign (breaking backwards compatibility). The codebase has accrued just a little too much tech debt, particularly in train_dalle.py and README.md.

    Training loop

    • The training loop needs to be refactored into a function in the style of PyTorch Lightning. This dramatically improves readability and makes early returns possible, which is a cleaner solution than raising an Exception.
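    Below is a minimal sketch of the loop-as-function idea in plain Python. train and step_fn are hypothetical names. The NaN check and the max_steps limit are example stopping conditions that exit with return rather than an exception.

```python
import math
from typing import Callable, Iterable, Optional

def train(batches: Iterable, step_fn: Callable[[object], float], max_steps: Optional[int] = None) -> int:
    """Run step_fn on each batch and return how many steps ran (hypothetical sketch)."""
    steps = 0
    for batch in batches:
        loss = step_fn(batch)  # forward, backward and optimizer step would live here
        steps += 1
        if math.isnan(loss):
            print(f"loss is NaN at step {steps}; stopping early")
            return steps
        if max_steps is not None and steps >= max_steps:
            return steps  # a plain return replaces raising an exception to exit the loop
    return steps

print(train(range(100), lambda batch: 0.5, max_steps=10))  # 10
```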

    Argument parsing

    • All argument parsing needs to be consolidated into a single file, perhaps another file at the root of the repo, since it's still something people will want to open up quickly.
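    Below is a sketch of what a single consolidated argument file might contain, using only argparse. The flag names and defaults are hypothetical placeholders, not the current train_dalle.py options.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Single home for training options (flag names and defaults are hypothetical)."""
    parser = argparse.ArgumentParser(description="Train DALL-E")

    model = parser.add_argument_group("model")
    model.add_argument("--dim", type=int, default=512)
    model.add_argument("--depth", type=int, default=2)
    model.add_argument("--heads", type=int, default=8)
    model.add_argument("--text_seq_len", type=int, default=256)

    training = parser.add_argument_group("training")
    training.add_argument("--batch_size", type=int, default=4)
    training.add_argument("--learning_rate", type=float, default=3e-4)
    training.add_argument("--clip_grad_norm", type=float, default=0.5)
    return parser

args = build_parser().parse_args(["--depth", "16", "--batch_size", "8"])
print(args.depth, args.batch_size, args.dim)  # 16 8 512
```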

    Style guide and automatic formatting in CI

    • This one's easy: we need a consistent format for multiple contributors to target. I like PEP8, but honestly, if we just pick one, I'll be happy we're using a linter/formatter at all. Alternatives are black/yapf. I could integrate this into GitHub Actions so that code is automatically formatted upon merge to main.

    Configuration:

    • We should provide a reasonable default deepspeed_config.json in a folder at the root called "config" and encourage people to update that file rather than the deepspeed_config variable in Python, which is a hassle to find each time. Each file should start with a comment linking to:
    • https://deepspeed.readthedocs.io/en/latest/schedulers.html # Various (otherwise) undocumented schedulers from DeepSpeed
    • https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training # Options specific to zero.
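    Below is a sketch of the "defaults plus config/deepspeed_config.json" idea. The keys are standard DeepSpeed configuration keys, but the values are placeholders. load_deepspeed_config is a hypothetical helper. Top-level keys in the file replace the defaults wholesale.

```python
import copy
import json
from pathlib import Path

# Placeholder values; the keys are standard DeepSpeed config keys.
DEFAULT_DEEPSPEED_CONFIG = {
    "train_batch_size": 16,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 0.5,
    "fp16": {"enabled": False},
    "zero_optimization": {"stage": 1},
}

def load_deepspeed_config(path: str = "config/deepspeed_config.json") -> dict:
    """Return the defaults, overridden by any top-level keys found in path."""
    config = copy.deepcopy(DEFAULT_DEEPSPEED_CONFIG)
    config_file = Path(path)
    if config_file.is_file():
        config.update(json.loads(config_file.read_text()))
    return config

print(load_deepspeed_config("no/such/dir/deepspeed_config.json")["train_batch_size"])  # 16
```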

    README.md

    • The README is a bit long, but that's alright. I prefer the idea of a single source of truth for documentation, especially one that doesn't require the internet after you've downloaded it (unlike, e.g., the wiki).
    • It just needs a table of contents and a "quick start" at the top. I do like the idea of continuing to give progress updates on the current best at the top as well. So let's keep that. taming-transformers has a wonderfully clean README.md with a table of contents which could be used as a basis.

    Feel free to criticize the current design (please be constructive and polite - we're all lazy coders at the end of the day) so we can decide if breaking backwards compatibility for a new design would be worthwhile.

    opened by afiaka87 12
  • DALL-E Image Embedding

    A token is any symbol from a discrete vocabulary; for humans, each English letter is a token from a 26-letter alphabet. DALL·E’s vocabulary has tokens for both text and image concepts. Specifically, each image caption is represented using a maximum of 256 BPE-encoded tokens with a vocabulary size of 16384, and the image is represented using 1024 tokens with a vocabulary size of 8192. The images are preprocessed to 256x256 resolution during training. Similar to VQVAE, each image is compressed to a 32x32 grid of discrete latent codes using a discrete VAE that we pretrained using a continuous relaxation. We found that training using the relaxation obviates the need for an explicit codebook, EMA loss, or tricks like dead code revival, and can scale up to large vocabulary sizes.

    We can use OpenAI's CLIP implementation to filter the good samples, but I would assume they didn't use it to create the embedding. So could we assume they used some kind of VQ-VAE? For example https://github.com/openai/vdvae or https://github.com/NVlabs/NVAE ?

    So this repo should have two-step training. Step 1: pretrain an autoencoder to tokenize the images. We could start small with a 16x16 embedding and a relatively low vocab size (2k-4k?). Step 2: train the decoder transformer. Here we should have a preprocessing step that converts the image-text pairs to tokens: some Hugging Face tokenizer for the text and the VQ-VAE encoder for the image.
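    The preprocessing step can be sketched with the figures quoted from OpenAI above: captions of up to 256 BPE tokens from a 16384-token vocabulary, and images as a 32x32 grid of codes from an 8192-entry codebook. Two choices here are assumptions for illustration: image codes are offset by the text vocabulary size so both share one output vocabulary, and captions are padded with id 0. The inputs stand in for real tokenizer and VQ-VAE encoder output.

```python
from typing import List

TEXT_VOCAB, TEXT_SEQ_LEN = 16384, 256   # caption: up to 256 BPE tokens
IMAGE_VOCAB, GRID = 8192, 32            # image: 32x32 grid of codebook indices

def build_sequence(text_ids: List[int], image_codes: List[int], pad_id: int = 0) -> List[int]:
    """Pad/truncate the caption, then append image codes shifted past the text vocab."""
    text = text_ids[:TEXT_SEQ_LEN]
    text = text + [pad_id] * (TEXT_SEQ_LEN - len(text))
    if len(image_codes) != GRID * GRID:
        raise ValueError(f"expected {GRID * GRID} image codes, got {len(image_codes)}")
    # Offset so image code c and text token c map to different ids.
    image = [TEXT_VOCAB + code for code in image_codes]
    return text + image

seq = build_sequence([17, 942, 5], list(range(GRID * GRID)))
print(len(seq), seq[TEXT_SEQ_LEN], TEXT_VOCAB + IMAGE_VOCAB)  # 1280 16384 24576
```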

    We hope that someone will offer pretrained model weights for CLIP to remove bad samples during inference. If it was trained on something like the Microsoft dataset, it should be general enough for most use cases.

    Some Open Questions:

    • They use sparse attention for the image part. Should we just use full attention for the whole network for now, or go fully sparse?
    • If it's not a VQ-VAE, which GANs work well with discrete latent values?
    • If it's a VQ-VAE, it's some kind of hierarchical one. Does DALL-E model the first latent value while the rest is just randomly sampled during reconstruction?

    opened by adrian-spataru 12
  • stable_softmax, wanb_entity, visible discord, replace buggy colab

    edit: alright rom1504 is being awesome and implementing things the proper modular way for us. I'm gonna focus this PR on a few outstanding issues

    It seems the CompVis team hasn't updated their PyPI package, because their latest pip wheel still doesn't contain the necessary GumbelVQ class. I've had to add taming-transformers as a submodule to get it to work, which doesn't feel quite right.

    opened by afiaka87 11
  • Out of memory errors no matter what parameters with deep speed

    Using these fairly lightweight parameters:

    BATCH_SIZE = 8
    LEARNING_RATE = 3e-4
    
    MODEL_DIM = 512
    TEXT_SEQ_LEN = 128
    DEPTH = 4
    HEADS = 4
    DIM_HEAD = 64
    REVERSIBLE = True
    LOSS_IMG_WEIGHT = 7
    

    A single V100 GPU only needs 6356MB of RAM.

    [0] Tesla V100-SXM2-16GB | 57'C, 81 % | 6356 / 16160 MB |

    When run with deepspeed - memory usage immediately balloons to filling up each GPU's 16 GiB of RAM until finally running out of memory before a single iteration completes.

    Aside: please don't take these personally, ha. We have pinned versions and whatnot; I'm just trying to be thorough so I can come back and try to fix them myself.

    Traceback (most recent call last):
      File "train_dalle.py", line 271, in <module>
        loss = distr_dalle(text, images, mask = mask, return_loss = True)
      File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 914, in forward
        loss = self.module(*inputs, **kwargs)
      File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/root/DALLE-pytorch/dalle_pytorch/dalle_pytorch.py", line 495, in forward
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:], ignore_index=0)
      File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 2422, in cross_entropy
        return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
      File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 1591, in log_softmax
        ret = input.log_softmax(dim)
    RuntimeError: CUDA out of memory. Tried to allocate 394.00 MiB (GPU 0; 15.78 GiB total capacity; 1.80 GiB already allocated; 178.75
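    The failing frame runs log_softmax over logits[:, :, self.text_seq_len:]. That allocates a new tensor with one float per batch item, vocabulary entry, and image position. A back-of-the-envelope estimator follows. The vocabulary and position counts are assumptions, not values given in this report: 49408 is the size of CLIP's BPE vocabulary, 1024 is a VQGAN codebook, and 256 positions is a 16x16 grid. With those assumed values the estimate equals the 394.00 MiB allocation in the traceback. That is suggestive, but it does not confirm the configuration.

```python
def logits_bytes(batch: int, vocab: int, positions: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed for a (batch, vocab, positions) tensor; 4 bytes per fp32 element."""
    return batch * vocab * positions * bytes_per_elem

# BATCH_SIZE = 8 comes from the settings above; vocab and positions are assumptions.
size = logits_bytes(batch=8, vocab=49408 + 1024, positions=256)
print(size / 2**20)  # 394.0 (MiB)
```

    The estimate scales linearly in every factor, so halving BATCH_SIZE halves this one allocation.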

    opened by afiaka87 11
  • Can't run example models in colab due to lightning error

    Hi, I've tried a few of the notebooks you provided for the examples, but I run into the same error across all of them in Colab:

    ImportError                               Traceback (most recent call last)
    <ipython-input-6-548ac97a7512> in <module>
         15 # dalle classes
         16 
    ---> 17 from dalle_pytorch import DiscreteVAE
         18 
         19 # constants
    
    4 frames
    /usr/local/lib/python3.7/dist-packages/taming/main.py in <module>
         10 from pytorch_lightning.trainer import Trainer
         11 from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
    ---> 12 from pytorch_lightning.utilities.distributed import rank_zero_only
         13 
         14 def get_obj_from_str(string, reload=False):
    
    ImportError: cannot import name 'rank_zero_only' from 'pytorch_lightning.utilities.distributed' (/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/distributed.py)
    
    

    I think the version of lightning that is installed by the script might be incorrect now?

    Any help getting this fixed is greatly appreciated!

    Thanks
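    The traceback shows taming/main.py importing rank_zero_only from pytorch_lightning.utilities.distributed, and the installed pytorch_lightning no longer provides it there. One workaround is to install an older pytorch_lightning release. Another is to try several import locations in turn, as sketched below. import_first is hypothetical. The newer module path reflects my understanding of where recent pytorch_lightning releases expose rank_zero_only, so check it against your installed version.

```python
import importlib
from typing import Any, Sequence, Tuple

def import_first(candidates: Sequence[Tuple[str, str]]) -> Any:
    """Return the attribute from the first (module, name) pair that imports successfully."""
    errors = []
    for module_name, attr in candidates:
        try:
            return getattr(importlib.import_module(module_name), attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("no candidate could be imported:\n" + "\n".join(errors))

# Assumed locations: newer pytorch_lightning releases first, the old path second.
RANK_ZERO_CANDIDATES = [
    ("pytorch_lightning.utilities.rank_zero", "rank_zero_only"),
    ("pytorch_lightning.utilities.distributed", "rank_zero_only"),
]

# Stdlib demo: the missing module is skipped and os.path.join is returned.
join = import_first([("no_such_module_xyz", "join"), ("os.path", "join")])
print(join("a", "b"))
```

    In a patched taming/main.py, the failing line would become rank_zero_only = import_first(RANK_ZERO_CANDIDATES).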

    opened by neramas1221 0
  • DALLE trained on FashionGen Dataset RESULTS 💯

    DALLE on FashionGen

    • I trained Dall-E + VQGAN on the FashionGen dataset (https://arxiv.org/abs/1806.08317) on Google Colab and got decent results.
    • Without training the VQGAN on the FashionGen dataset, DALLE is really bad at generating faces, which makes clothing generations look extremely strange.

    Text to image generation and re-ranking by CLIP

    Best 16 of 48 generations ranked by CLIP

    Generations from the training set (including their ground truths)

    Generations based on custom prompts (without their ground truths)

    Model specifications

    VAE
    • Trained VQGAN for 1 epoch on the Fashion-Gen dataset
    • Embeddings: 1024
    • Batch size: 5

    DALLE
    • Trained DALLE for 1 epoch on the Fashion-Gen dataset
    • dim = 312
    • text_seq_len = 80
    • depth = 36
    • heads = 12
    • dim_head = 64
    • reversible = 0
    • attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')

    Optimization
    • Optimizer: Adam
    • Learning rate: 4.5e-4
    • Gradient clipping: 0.5
    • Batch size: 7
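    The run above pairs depth = 36 with four attention types. The sketch below assumes attn_types is repeated layer by layer until depth is reached; that is an assumption about how the tuple is applied, not something stated in this post. Under that assumption, each type covers 9 of the 36 layers.

```python
from collections import Counter
from itertools import cycle, islice

def layer_attention_types(attn_types, depth):
    """Repeat the attn_types tuple until there is one entry per layer (assumed behavior)."""
    return list(islice(cycle(attn_types), depth))

layers = layer_attention_types(('full', 'axial_row', 'axial_col', 'conv_like'), depth=36)
print(layers[:5])  # ['full', 'axial_row', 'axial_col', 'conv_like', 'full']
print(Counter(layers))  # 9 layers of each type
```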

    opened by alexriedel1 8
  • Text transformers

    Hi again :) Is there any way to easily change the transformer architecture, as in x-clip? I would like to use my own (which is pretrained). Thanks!

    opened by ethancohen123 1
  • dvae training resulting in an irregular latent space

    Hello! I'm not sure whether this should be raised as an issue or it is a fault completely on my side. But I've reached a point where I can't seem to figure it out on my own, so I hope someone could enlighten me.

    I'm trying to train the DiscreteVAE with some custom dataset, but the trained model seems to fail in learning a regular latent space.

    For instance, when I decode the codebook indices obtained from one of my dataset images, the output image looks fine. But when I try to interpolate between two sets of indices, the points in between decode to completely unrecognizable images.

    I am told that the KL loss has something to do with regularizing the latent space, but according to some previously raised issues, it does not seem to be a usable option.

    Is there a known reason for this kind of irregularity in the latent space? Or rather, has anyone achieved smooth latent interpolation when training a DiscreteVAE model? It would be really helpful if someone who has succeeded could tell me about the relevant parameters.

    opened by hlp-pls 0
Releases(1.6.4)
Owner
Phil Wang
Working with Attention. It's all we need.