Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

Overview

DALL-E in Pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the generations.

Sid, Ben, and Aran over at Eleuther AI are working on DALL-E for Mesh Tensorflow! Please lend them a hand if you would like to see DALL-E trained on TPUs.

Yannic Kilcher's video

Before we replicate this, we can settle for Deep Daze or Big Sleep

Status

Hannu has managed to train a small 6-layer DALL-E on a dataset of just 2000 landscape images! (2048 visual tokens)

Install

$ pip install dalle-pytorch

Usage

Train VAE

import torch
from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,          # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,       # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,      # codebook dimension
    hidden_dim = 64,         # hidden dimension
    num_resnet_blocks = 1,   # number of resnet blocks
    temperature = 0.9,       # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False # straight-through for gumbel softmax. unclear if it is better one way or the other
)

images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_recon_loss = True)
loss.backward()

# train with a lot of data to learn a good codebook
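
A minimal training-loop sketch around the call above, continuing from the vae defined earlier (dataloader is an assumed DataLoader of image batches; the optimizer and hyperparameters are illustrative, not prescribed by the repo):

import torch

opt = torch.optim.Adam(vae.parameters(), lr = 3e-4)   # illustrative learning rate

for epoch in range(20):                               # illustrative epoch count
    for images in dataloader:                         # assumed to yield (batch, 3, 256, 256) tensors
        loss = vae(images, return_recon_loss = True)
        opt.zero_grad()
        loss.backward()
        opt.step()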

Train DALL-E with pretrained VAE from above

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,         # should match the VAE trained above
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 12,                 # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = dalle(text, images, mask = mask, return_loss = True)
loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(text, mask = mask)
images.shape # (4, 3, 256, 256)
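
The generations come back as an image tensor, so they can be written straight to disk, for example with torchvision's save_image (the filenames here are arbitrary):

from torchvision.utils import save_image

for i, image in enumerate(images):
    save_image(image, f'dalle_generation_{i}.png')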

Ranking the generations

Train CLIP

import torch
from dalle_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()

To get the similarity scores from your trained Clipper, just do

images, scores = dalle.generate_images(text, mask = mask, clip = clip)

scores.shape # (4,)
images.shape # (4, 3, 256, 256)

# do your topk here; in the paper they sampled 512 and chose the top 32
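
As a sketch, filtering down to the highest-scoring generations could look like this (the paper's setting would be top 32 of 512 samples; smaller numbers here only to match the toy batch):

top_k = 2
best_scores, best_indices = scores.topk(top_k)
best_images = images[best_indices]   # (top_k, 3, 256, 256)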

Or you can just use the official CLIP model to rank the images from DALL-E

Scaling depth

In the blog post, they used 64 layers to achieve their results. I added reversible networks, from the Reformer paper, in order for users to attempt to scale depth at the cost of compute. Reversible networks allow you to scale to any depth at no memory cost, but a little over 2x compute cost (each layer is rerun on the backward pass).

Simply set the reversible keyword to True for the DALLE class

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True  # <-- reversible networks https://arxiv.org/abs/2001.04451
)

Sparse Attention

You can also train with Microsoft Deepspeed's Sparse Attention, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process.

First, you need to install Deepspeed with Sparse Attention

$ sh install_deepspeed.sh

Next, you need to install the pip package triton

$ pip install triton

If both of the above succeeded, now you can train with Sparse Attention!

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 8,
    sparse_attn = (True, False) * 32  # interleave sparse and dense attention for 64 layers
)

Citations

@misc{unpublished2021dalle,
    title   = {DALL·E: Creating Images from Text},
    author  = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray},
    year    = {2021}
}
@misc{unpublished2021clip,
    title   = {CLIP: Connecting Text and Images},
    author  = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
    year    = {2021}
}
@misc{kitaev2020reformer,
    title   = {Reformer: The Efficient Transformer},
    author  = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
    year    = {2020},
    eprint  = {2001.04451},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

Those who do not want to imitate anything, produce nothing. - Dali

Issues
  • VQGanVAE1024 "vae must be an instance of DiscreteVAE"

    VQGanVAE1024 "vae must be an instance of DiscreteVAE"

    @lucidrains I believe the relevant line is here:

    https://github.com/lucidrains/DALLE-pytorch/blob/2268864941d8eef2ba73a4488fe05673d447d493/dalle_pytorch/dalle_pytorch.py#L306

    I tried adding it in myself, but it needs the taming imports and I'm not familiar with those.

    opened by afiaka87 64
  • More "OpenAI Blog Post" Training | Depth 32 | Heads 8 | LR 5e-4

    More "OpenAI Blog Post" Training | Depth 32 | Heads 8 | LR 5e-4

    Edit: Moved to discussions: https://github.com/lucidrains/DALLE-pytorch/discussions/106

    Hey, all. Some of you might know I'm practicing and learning about machine learning with dalle-pytorch and a dataset consisting of the images OpenAI presented in the DALL-E blog post. I honestly don't have the money to train this whole dataset.

    Edit: this is no longer true. Using the 1024 VQGAN from the "Taming Transformers" research, it's now quite possible to train a full dataset of 1,000,000 image-text pairs, and I'm doing just that. I hope to have it finished in about a week. I assume someone else will release a dalle-pytorch trained properly on COCO and other image sets before then, but if they don't, check here for updates.

    Anyway, it ran for ~36000 steps. As you can see, it...still really likes mannequins. I'm considering removing them from the dataset. But also, you'll notice that the network has actually got a decent idea of the sort of general colors that belong in types of prompts.

    Some Samples from Near the End of Training


    Every Text-Image Reconstruction

    https://wandb.ai/afiaka87/dalle_pytorch_live_training/reports/dalle-pytorch-Test-Run-2--Vmlldzo1MzM5MjQ

    Deliverables (my train_dalle.py)

    https://gist.github.com/afiaka87/850fb3cc48edde8a7ed4cb1ce53b6bd2

    This has some code in it that actually manages to deal with truncated images via try/except. Apparently detecting a corrupted PNG is harder than P vs NP. PIL's verify() function doesn't catch all of them. Python's built-in imghdr library doesn't catch all of them either. So you just sort of catch OSError and return an item further along. Works well enough.
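
    A rough sketch of that pattern (a hypothetical dataset class, not the code from the gist above; paths and transform are assumed to be a list of image file paths and a torchvision transform):

    import PIL.Image
    from torch.utils.data import Dataset

    class SkipCorruptImages(Dataset):
        def __init__(self, paths, transform):
            self.paths = paths
            self.transform = transform

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, index):
            try:
                image = PIL.Image.open(self.paths[index]).convert('RGB')
                return self.transform(image)
            except OSError:
                # corrupted / truncated file: serve the next item instead
                return self.__getitem__((index + 1) % len(self.paths))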

    Parameters

    SHUFFLE = True
    EPOCHS = 28 # This wound up being less than a single epoch, of course. 
    BATCH_SIZE = 16
    LEARNING_RATE = 0.0005 # I found this learning rate to be more suitable than 0.0003 in my hyperparameter sweep post
    GRAD_CLIP_NORM = 0.5
    DEPTH = 32
    HEADS = 8
    MODEL_DIM = 512
    TEXT_SEQ_LEN = 256
    DIM_HEAD = 64
    REVERSIBLE = True
    ATTN_TYPES = ('full',)
    

    Dataset Description

    https://github.com/lucidrains/DALLE-pytorch/issues/61#issuecomment-796663342

    Just for more info on the dataset itself: it is roughly 1,100,000 256x256 image-text pairs that were generated by OpenAI's DALL-E. They presented roughly ~30k unique text prompts, of which they posted the top 32 of 512 generations on https://openai.com/blog/dall-e/. Many images were corrupt, and not every prompt has a full 32 examples, but the total number of images winds up being about 1.1 million. If you look at many of the examples on that page, you'll see that DALL-E (in that form at least) can and will make mistakes. These mistakes are also in this dataset. Anyway, I'm just messing around having fun training and whatnot. This is definitely not going to produce a good model or anything.

    There are also a large number of images in the dataset which are intended to be used with the "mask" feature. I don't know if that's possible yet in DALLE-pytorch though. Anyway, that can't be helping much.

    opened by afiaka87 31
  • Huge GPU memory consumption when using DeepSpeed

    Huge GPU memory consumption when using DeepSpeed

    So I decided to try the recently introduced DeepSpeed training on 4x V100 GPUs. Previously I was training on a single T4 or V100 (both have 16GB RAM) with certain model parameters (1024 model dim, 16 layers, 8 heads, 64 head dim if that's important). I was able to use a batch size of 2 with this configuration (256px images and no texts).

    I tried to launch distributed training with DeepSpeed with the same model parameters, but to my surprise, it gave an OOM error. Using binary search I found that it's possible to train the model using DeepSpeed but only when I set the number of layers to 4 (4x reduction) and use only a single sample per batch per GPU (so, 2x reduction).

    Am I missing something? I'm using the latest master branch with some minor changes that are very unlikely to cause such behavior.

    Any help/suggestion is very much appreciated!

    opened by ex4sperans 29
  • Add tensorboard support.

    Add tensorboard support.

    Weights & Biases is cool - but they don't really support the offline use case as well as tensorboard does currently. This fix simply adds tensorboard as a fallback option - W&B will still run if it's available.
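
    A sketch of what such a fallback could look like (not the PR's actual code; the log directory is arbitrary):

    try:
        import wandb                                   # preferred when installed and configured
        use_wandb = True
    except ImportError:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(log_dir = './runs')     # offline-friendly fallback
        use_wandb = False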

    opened by afiaka87 22
  • Trained for 17K iters on COCO2014, OpenImages and OpenAI's Blog Images

    Trained for 17K iters on COCO2014, OpenImages and OpenAI's Blog Images

    In case you haven't read my usual disclaimer: this dataset is weird. The repetition in the OpenAI images causes those to be highly overfit (mannequins), and the remainder of the dataset is much more diverse, which dalle-pytorch doesn't manage to capture very well here. Also, keep in mind - this isn't even a full epoch. Just having fun. Try not to evaluate this as representative of dalle-pytorch's current capabilities.


    Hey everyone. @lucidrains got the new, lighter pretrained VAE from the taming-transformers group recently. It uses substantially less memory and compute. I decided to take all the datasets I've collected thus far, put them in a single folder on an A100, and train dalle-pytorch for several hours.

    Here are the results:

    https://wandb.ai/afiaka87/OpenImagesV6/reports/Training-on-COCO-OpenImage-Blogpost--Vmlldzo1NDE3NjU

    I'm exhausted so that's all for now, but please click the link and have a look at the thousands of reconstructions it made (and the horrible captions from the "Localized Narratives" dataset I got from Google). I'll be updating this post with more info throughout the day.

    opened by afiaka87 19
  • Deepspeed fix : save the normal model too

    Deepspeed fix : save the normal model too

    useful to be able to get the model for generation even when using deepspeed

    opened by rom1504 18
  • Added support for webdataset

    Added support for webdataset

    opened by robvanvolt 17
  • Ideas for datasets to use? (WIT just came out)

    Ideas for datasets to use? (WIT just came out)

    Hey all,

    I'm compiling a list of the various datasets we'll need and how to download them:

    Keep in mind, not all of these datasets ship with captions. However many of them do ship with a class descriptor of some type. I've only done mild testing with this, but usually you can just generate labels by doing something like "an image of {class_name}". Not sure what the best way to go about that would be though.
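
    For the class-labelled sets, a naive caption template would be enough to get started (illustrative helper, not part of the repo):

    def caption_from_class(class_name):
        # e.g. 'tabby_cat' -> 'an image of tabby cat'
        return 'an image of ' + class_name.replace('_', ' ')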

    https://github.com/lucidrains/DALLE-pytorch/discussions/109

    As it stands, this is turning out to be humongous. I just added the new Wikipedia dataset (11 million images).

    Does anyone know of other captioned datasets we could use?

    opened by afiaka87 17
  • Can I use Chinese text?

    Can I use Chinese text?

    Can I use Chinese text?

    opened by LIMr1209 15
  •  DALL-E Image Embedding

    DALL-E Image Embedding

    A token is any symbol from a discrete vocabulary; for humans, each English letter is a token from a 26-letter alphabet. DALL·E’s vocabulary has tokens for both text and image concepts. Specifically, each image caption is represented using a maximum of 256 BPE-encoded tokens with a vocabulary size of 16384, and the image is represented using 1024 tokens with a vocabulary size of 8192. The images are preprocessed to 256x256 resolution during training. Similar to VQVAE, each image is compressed to a 32x32 grid of discrete latent codes using a discrete VAE that we pretrained using a continuous relaxation. We found that training using the relaxation obviates the need for an explicit codebook, EMA loss, or tricks like dead code revival, and can scale up to large vocabulary sizes.
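
    Following the numbers quoted above, the transformer's input sequence per training example works out as:

    text_tokens   = 256          # BPE caption tokens, vocab size 16384
    image_tokens  = 32 * 32      # 1024 discrete VAE codes, vocab size 8192
    total_seq_len = text_tokens + image_tokens   # 1280 tokens in total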

    We can use OpenAI's CLIP implementation to filter the good samples, but I would assume they didn't use it to create the embedding. So we could assume they used some kind of VQ-VAE? For example https://github.com/openai/vdvae or https://github.com/NVlabs/NVAE ?

    So this repo should have a 2-step training process. Step 1 - Pretrain an autoencoder to tokenize the images. We could go small first and do it with a 16x16 embedding and a relatively low vocab size (2k-4k?). Step 2 - Train the decoder transformer. Here we should have a preprocessing step to convert the image-text pairs to tokens: some Hugging Face tokenizer for the text and the encoder of the VQ-VAE for the image.

    We hope that someone will offer pretrained model weights for CLIP to remove bad samples during inference. If it was trained on something like the Microsoft dataset, then it should be general enough for most use cases.

    Some Open Questions:

    • They use sparse attention for the image part. We could just use full attention for the whole network for now, or go fully sparse?
    • If it's not a VQ-VAE, which GANs work well with discrete latent values?
    • If it's a VQ-VAE, it's some kind of hierarchical one. Does DALL-E model the first latent value while the rest is just randomly sampled during reconstruction?
    opened by adrian-spataru 12
  • Using pre-trained text-to-text language model in place of transformer part of the DALLE model.

    Using pre-trained text-to-text language model in place of transformer part of the DALLE model.

    Thank you for making this repository to imitate OpenAI's DALL-E. I was thinking that it would be efficient if we just used a pre-existing SOTA text-to-text language model for the transformer part of DALL-E, added a very deep feed-forward layer as an adapter between the VAE and the transformer, and then only trained that adapter layer, with the VAE and transformer being frozen during training. For example, using GPT-J / Reformer (long sequence lengths would help in generating high-res images) / DeBERTa as the transformer (although DeBERTa would also require changing the script to support masked language modelling). This would (in theory) be way less compute-expensive. What are your thoughts about it?

    opened by Vbansal21 0
  • New error using the new update.

    New error using the new update.

    [2021-09-13 11:39:11,114] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed info: version=0.5.1, git-hash=unknown, git-branch=unknown [2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed groups [2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed model parallel group with size 1 [2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed expert parallel group with size 1 [2021-09-13 11:39:11,217] [INFO] [logging.py:68:log_dist] [Rank 0] creating expert data parallel process group with ranks: [0] [2021-09-13 11:39:11,217] [INFO] [logging.py:68:log_dist] [Rank 0] creating expert parallel process group with ranks: [0] [2021-09-13 11:39:11,240] [INFO] [engine.py:198:init] DeepSpeed Flops Profiler Enabled: False Traceback (most recent call last): File "train_dalle.py", line 497, in config_params=deepspeed_config, File "/home/valterjordan/DALLE-pytorch/dalle_pytorch/distributed_backends/distributed_backend.py", line 152, in distribute **kwargs, File "/home/valterjordan/DALLE-pytorch/dalle_pytorch/distributed_backends/deepspeed_backend.py", line 162, in _distribute **kwargs, File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/init.py", line 141, in initialize config_params=config_params) File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 204, in init self.training_dataloader = self.deepspeed_io(training_data) File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1188, in deepspeed_io data_parallel_rank=data_parallel_rank) File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/dataloader.py", line 52, in init rank=data_parallel_rank) File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/torch/utils/data/distributed.py", line 87, in init self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore TypeError: object of type 'Processor' has no len()

    opened by jaoeded 2
  • Running in Colab: AttributeError: module 'keras.backend' has no attribute 'is_tensor'

    Running in Colab: AttributeError: module 'keras.backend' has no attribute 'is_tensor'

    Hi, first of all thanks for the amazing job and repo!!! Trying to run the Colab notebook, I'm getting the following error on the last cell and no image is generated.

    AttributeError: module 'keras.backend' has no attribute 'is_tensor'

    Thanks in advance for your help

    opened by guillermoAlv 4
  • Generations taken from Train Set (CUB200 dataset)

    Generations taken from Train Set (CUB200 dataset)

    Hi everyone, thank you for the amazing job on this repository. I'm trying to replicate @kobiso's work on the CUB200 dataset (#131) on a 2-GPU architecture. I'm not directly using your script "train_dalle.py" but a very similar one. The main difference is that I'm using data parallel to train DALL-E in a multi-GPU environment.

    Here are some details of my implementation:

    # Step 1: load pretrained VQGAN. SET PARAMETERS
    vae = VQGanVAE()
    
    # Define Parameters
    DIM = 256
    TEXT_SEQ_LEN = 80
    DEPTH = 8
    HEADS = 8
    DIM_HEAD = 64
    ATTN_TYPES = ('full', 'axial_row', 'axial_col', 'conv_like')
    
    # Training Parameters
    LEARNING_RATE = 0.0006
    EPOCHS = 500                      # but model automatically stops after 300 epochs
    BATCH_SIZE = 64          
    
    # Adamw Optimizer. 
    BETAS = (0.9, 0.96)
    EPS = 1e-08
    WEIGHT_DECAY = 4.5e-2
    
    # gradient clipping
    MAX_NORM = 0.5  
    
    # Scheduler
    FACTOR = 0.5               
    PATIENCE = 5                
    COOLDOWN = 10               
    MIN_LEARNING_RATE = LEARNING_RATE * (FACTOR**6)   # allow lr reduction 6 times
    
    # Step 2: Dataset and tokenizer
    IMAGE_SIZE = vae.image_size
    
    ds = TextImageDataset(
            folder=dataset_path,
            text_len=TEXT_SEQ_LEN,
            image_size=IMAGE_SIZE,
            truncate_captions=True,
            tokenizer=tokenizer,        # default tokenizer
            shuffle=True,
        )
    
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    
    # Step 3: define DALLE
    dalle = DALLE(
        vae=vae,  # pretrained
        num_text_tokens=tokenizer.vocab_size,
        text_seq_len=TEXT_SEQ_LEN,
        dim=DIM,
        depth=DEPTH,
        heads=HEADS,
        dim_head=DIM_HEAD,
        attn_types=ATTN_TYPES
    ).cuda()
    
    # data parallel on 2 gpu
    dalle = torch.nn.DataParallel(dalle, device_ids=gpu_cores)
    
    
    # optimizer:
    def group_weight(model):
        group_decay, group_no_decay = [], []
        for params in model.named_parameters():
            if 'transformer' in params[0]:
                if 'bias' in params[0] or 'norm' in params[0]:
                    group_no_decay.append(params[1])
                    continue
            group_decay.append(params[1])
    
        assert len(list(model.parameters())) == len(group_decay) + len(group_no_decay)
        groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
        return groups
    
    opt = torch.optim.AdamW(group_weight(dalle), lr=LEARNING_RATE, betas=BETAS, eps=EPS, weight_decay=WEIGHT_DECAY)
    
    # ReduceLRonPlateau
    scheduler = ReduceLROnPlateau(opt, mode='min', factor=FACTOR, patience=PATIENCE, cooldown=COOLDOWN,
                                                         min_lr=MIN_LEARNING_RATE, verbose=True)
    
    

    This is the loss plot obtained during training. Learning rate reductions are clearly visible.

    [W&B loss chart]

    Briefly speaking, my problem is: when I try to generate images from new captions, in some cases everything looks fine (in the following image I'm showing the best 3/64 according to CLIP - the generations aren't as decent as @kobiso's, but acceptable):

    [image: best 3 of 64 generations]

    In other cases, the images are direct reproductions of training set images! For example, the first image of the following generation is taken from the file Clay_Colored_Sparrow_0104_110699.jpg:

    [generated images]

    [training image: Clay_Colored_Sparrow_0104_110699.jpg]

    Where some captions of this particular image are quite similar to the one I've passed as an input:

    the bird has a small bill that is orange and black. the wings are tri-colored, grey, russet, and brown, with similar markings on cheek patch and eyebrow. a small bird with a white throat and belly, with brown, black and white feathers covering its wings. this bird has a white belly and breast, with a brown superciliary and sharp bill. a small bird with a white belly and brown head and wings. a small bird with brown, black, and white speckled wings and orange beak and feet. this bird has wings that are brown and black with a white belly this colorful bird has an orange bill, orange feet, and brown and orange wings. this bird has a light brown and dark grey striped crown, an orange beak, and a light grey breast. this bird has grey coloring on its breast and belly, and brown and black strips on its wings.

    Maybe someone can help me understand if this is normal behaviour or not, and how to debug this?

    Thank you.

    Update: I have made some tests with basic captions. The behaviour still appears. My idea is that training DALL-E on a small dataset like CUB doesn't give the model a great generalization ability. @lucidrains @kobiso can you confirm that? Here I post some examples:

    [example generations]

    opened by SerezD 0
  • Aspiring to go from VQ-VAE -> DALLE on Google Conceptual Captions Dataset

    Aspiring to go from VQ-VAE -> DALLE on Google Conceptual Captions Dataset

    Hi, I recently read this blog and was fascinated by the potential of these generative models. I am hoping to learn the fundamentals, reimplement models, and reproduce results from scratch. As a first step I found this repository to be VERY helpful. I can use the code here to replicate results as I am getting familiar with the theory. The blogs that have been useful to me so far:

    • https://ml.berkeley.edu/blog/posts/clip-art/
    • https://ml.berkeley.edu/blog/posts/vq-vae/
    • https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
    • https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html

    To make it easy for myself I am using the models from here but building my own scaffolding around the code. The repository is here. This is very much a work in progress (I work on it when I get time).

    At present, I am training the VQ-VAE model on ~2M Google Conceptual Captions images on a system with 2 Titan RTXs. The training progress can be seen here.

    I will kill this training once the images start looking good. Then I will move on to the DALL-E part. The real fun (pain) will start then, perhaps.

    I will try to keep this ticket updated with progress.

    opened by appliedml85 7
  • when generating images, error: __init__() got an unexpected keyword argument 'device'

    when generating images, error: __init__() got an unexpected keyword argument 'device'

    Env: official Docker image. Model: from https://github.com/robvanvolt/DALLE-models

    [email protected]:/workspace/dalle# python generate.py --dalle_path ./model/dalle_checkpoint.pt --text 'fireflies in a field under a full moon' --taming --vqgan_model_path ./model/vqgan.1024.model.ckpt --vqgan_config_path ./model/vqgan.1024.config.yml Working with z of shape (1, 256, 16, 16) = 65536 dimensions. loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth VQLPIPSWithDiscriminator running with hinge loss. Loaded VQGAN from ./model/vqgan.1024.model.ckpt and ./model/vqgan.1024.config.yml generating images for - fireflies in a field under a full moon: 0%| | 0/32 [00:00<?, ?it/s] 0it [00:00, ?it/s]for - fireflies in a field under a full moon: 0%| | 0/32 [00:00<?, ?it/s] Traceback (most recent call last): File "generate.py", line 116, in output = dalle.generate_images(text_chunk, filter_thres = args.top_k) File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "/workspace/dalle/dalle_pytorch/dalle_pytorch.py", line 42, in inner out = fn(model, *args, **kwargs) File "/workspace/dalle/dalle_pytorch/dalle_pytorch.py", line 480, in generate_images logits = self(text, image, mask = mask)[:, -1, :] File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/dalle_pytorch.py", line 552, in forward out = self.transformer(tokens) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/transformer.py", line 142, in forward return self.layers(x, **kwargs) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/reversible.py", line 156, in forward out = _ReversibleFunction.apply(x, blocks, args) File "/workspace/dalle/dalle_pytorch/reversible.py", line 113, in forward x = block(x, **kwarg) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/reversible.py", line 65, in forward y1 = x1 + self.f(x2, record_rng=self.training, **f_args) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/reversible.py", line 40, in forward return self.net(*args, **kwargs) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/transformer.py", line 53, in forward return self.fn(x, **kwargs) * self.scale File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/transformer.py", line 62, in forward return self.fn(self.norm(x), **kwargs) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/workspace/dalle/dalle_pytorch/attention.py", line 362, in forward out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask) File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File 
"/opt/conda/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/sparse_self_attention.py", line 152, in forward attn_output_weights = sparse_dot_sdd_nt(query, key) File "/opt/conda/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/matmul.py", line 745, in call time_db) File "/opt/conda/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/matmul.py", line 549, in forward c_time) File "/opt/conda/lib/python3.7/site-packages/deepspeed/ops/sparse_attention/matmul.py", line 192, in _sdd_matmul num_warps=4) TypeError: init() got an unexpected keyword argument 'device'


    opened by zxy2020 0
  • Getting Text to Output a Font/Handwriting of the Same Text

    Getting Text to Output a Font/Handwriting of the Same Text

    Discussed in https://github.com/lucidrains/DALLE-pytorch/discussions/339

    [sample generations: photooftheflag, an_illustration]

    Originally posted by afiaka87 July 17, 2021 I've been training a DALL-E with the goal of seeing whether or not a caption could be used to visualize the text itself in RGB pixels. I'm limited by my GPU but early results are certainly interesting. I'm using the oft-ignored weights from OpenAI's dVAE under the assumption it would better represent text (because that is mentioned as an explicit goal in the DALL-E paper). But these early results are promising, so I'm switching back to the pretrained VQGAN from CompVis to see if it can represent letters graphically as well as the dVAE.

    I'm using the new augly framework for the text transforms. The code looks like this:

    # in dalle_pytorch/loader.py
    import augly.image as imaugs
    import augly.text as textaugs
    # ...
    self.image_transform = T.Compose([
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.CenterCrop((192, 256)) # `memeify` takes up the top 64 pixels.
    ])
    # ...
    substring = description[:20] + "\n" + description[20:40]
    pil_image = PIL.Image.open(image_file)
    top_cut_image = self.image_transform(pil_image)
    aug_image = imaugs.meme_format(top_cut_image, text=substring, opacity=1.0, caption_height=64)
    image_tensor = T.Compose([
        #T.CenterCrop(256),
        T.ToTensor(),
    

    In this last example it seems to be reusing codes found for text-gen in the actual image itself, which is what I was hoping for. [screenshot from 2021-07-17]

    opened by afiaka87 0
  • 20 Epochs on COCO - (Larger Transformer)

    20 Epochs on COCO - (Larger Transformer)

    Discussed in https://github.com/lucidrains/DALLE-pytorch/discussions/335

    Originally posted by afiaka87 July 11, 2021 Full W&B training session

    [sample generations from the W&B run]

    Details

    Transformer:

    • Visual Dim - 512
    • Max Text Length/Language Dim - 80
    • Depth - 16
    • Heads - 16
    • Head Dim - 64
    • Attention: (axial_row, axial_row, axial_col, axial_row, axial_row, axial_row, axial_col, axial_row, axial_row, axial_row, axial_col, full, axial_row, axial_row, axial_col, full) # note the two layers of dense attention
    • lr_decay = True
    • Reversible = False

    Hardware:

    • 1x RTX 2070 Super (8 GiB VRAM)
    • 1x AMD Ryzen 3900
    • 32 GiB RAM

    Checkpoints: You can find a checkpoint for each epoch trained here: https://wandb.ai/dalle-pytorch-replicate/COCO512_16_16D_16H_80TSL/artifacts/model/trained-dalle/07c445559fd9183e302e

    opened by afiaka87 0
  • RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm_out

    RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm_out

    I get the error below when running with --amp: deepspeed train_dalle.py --image_text_folder /home/18zs11/dataset/cub/image-and-text --taming --bpe_path ./bpe.model --distr_backend deepspeed --amp. Does anyone have the same error?

    [2021-07-06 22:13:06,616] [INFO] [logging.py:68:log_dist] [Rank 0] Saving model checkpoint: cub/dalle_cub-ds-cp/global_step0/mp_rank_00_model_states.pt
    /home/18zs11/DALLE-pytorch/dalle_pytorch/tokenizer.py:263: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    result[i, :len(tokens)] = torch.tensor(tokens)
    /home/18zs11/DALLE-pytorch/dalle_pytorch/tokenizer.py:263: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    result[i, :len(tokens)] = torch.tensor(tokens)
    Traceback (most recent call last):
    File "train_dalle.py", line 562, in
    loss = distr_dalle(text, images, return_loss=True)
    File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs) File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1102, in forward loss = self.module(*inputs, **kwargs)

    File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "/home/18zs11/DALLE-pytorch/dalle_pytorch/dalle_pytorch.py", line 485, in forward image = self.vae.get_codebook_indices(image) File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context return func(*args, **kwargs) File "/home/18zs11/DALLE-pytorch/dalle_pytorch/vae.py", line 196, in get_codebook_indices _, _, [_, _, indices] = self.model.encode(img) File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/taming/models/vqgan.py", line 58, in encode quant, emb_loss, info = self.quantize(h) File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/taming/modules/vqvae/quantize.py", line 282, in forward torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) File "/home/18zs11/anaconda3/envs/clip/lib/python3.7/site-packages/torch/functional.py", line 344, in einsum return _VF.einsum(equation, operands) # type: ignore RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm_out

    opened by Gitsamshi 3
Releases (1.0.7)
Owner: Phil Wang
Working with Attention. It's all we need.