OpenAI's DALL-E for large-scale training in mesh-tensorflow.

Overview

DALL-E in Mesh-Tensorflow [WIP]

OpenAI's DALL-E in Mesh-Tensorflow.

If this is as efficient as GPT-Neo, this repo should be able to train models up to, and larger than, the size of OpenAI's DALL-E (12B params).

No pretrained models... Yet.

Thanks to Ben Wang for the TF VAE implementation as well as for getting the MTF version working, and to Aran Komatsuzaki for help building the MTF VAE and input pipeline.

Setup

git clone https://github.com/EleutherAI/GPTNeo
cd GPTNeo
pip3 install -r requirements.txt

Training Setup

Runs on TPUs; untested on GPUs, but it should work in theory. The example configs are designed to run on a TPU v3-32 pod (note that the example mesh_shape, "data:16,model:2", maps onto its 32 cores).

To set up TPUs, sign up for Google Cloud Platform, and create a storage bucket.

Create your VM through Google Cloud Shell (https://ssh.cloud.google.com/) with ctpu up --vm-only so that it can connect to your Google bucket and TPUs, then set up the repo as above.
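For example, a minimal Cloud Shell session might look like the following (the VM name, zone, and bucket name are illustrative placeholders, not values prescribed by this repo):

ctpu up --vm-only --name dalle-vm --zone us-central1-b
gsutil mb gs://your-bucket-name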

VAE pretraining

DALL-E needs a pretrained VAE to compress images to tokens. To run the VAE pretraining, point the dataset paths in configs/vae_example.json at a glob path for a dataset of jpgs, and set image_size to the appropriate size.

  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg",
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg",
    "image_size": 32
  }

Once this is all set up, create your TPU, then run:

python train_vae_tf.py --tpu your_tpu_name --model vae_example

Training logs image tensors and loss values. To check progress, you can run:

tensorboard --logdir your_model_dir

Dataset Creation [DALL-E]

Once the VAE is pretrained, you can move on to DALL-E.

Currently we are training on a dummy dataset. A public, large-scale dataset for DALL-E is in the works. In the meantime, to generate some dummy data, run:

python src/data/create_tfrecords.py

This should download CIFAR-10 and generate some random captions to act as text inputs.

Custom datasets should be formatted in a folder, with a jsonl file in the root folder containing caption data and paths to the respective images, as follows:

Folder structure:

        data_folder
            jsonl_file
            folder_1
                img1
                img2
                ...
            folder_2
                img1
                img2
                ...
            ...

jsonl structure:
    {"image_path": folder_1/img1, "caption": "some words"}
    {"image_path": folder_2/img2, "caption": "more words"}
    ...
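For illustration, here is a minimal Python sketch that writes a caption file in this format (the file name and entries are placeholders, not prescribed by the repo):

    import json

    # Hypothetical entries; image paths are relative to the dataset root folder.
    records = [
        {"image_path": "folder_1/img1.jpg", "caption": "some words"},
        {"image_path": "folder_2/img2.jpg", "caption": "more words"},
    ]

    # jsonl means one JSON object per line.
    with open("data_folder/captions.jsonl", "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")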

You can then use the create_paired_dataset function in src/data/create_tfrecords.py to encode the dataset into tfrecords for use in training.
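The function's exact signature isn't documented here, so the call below is only a sketch: create_paired_dataset does exist in the script, but the positional argument is an assumption (the script's own entry point passes tokenizer=None); check src/data/create_tfrecords.py for the real interface.

    from src.data.create_tfrecords import create_paired_dataset

    # Assumed call shape -- verify against the script before using.
    create_paired_dataset("data_folder/captions.jsonl", tokenizer=None)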

Once the dataset is created, copy it over to your bucket with gsutil:

gsutil cp -r DALLE-tfrecords gs://neo-datasets/

And finally, run training with:

python train_dalle.py --tpu your_tpu_name --model dalle_example

Config Guide

VAE:

{
  "model_type": "vae",
  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg", # glob path to training images
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg", # glob path to eval images
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, 
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000, # how often to save a checkpoint
  "iterations": 500, # number of batches to infeed to the tpu at a time. Must be < steps_per_checkpoint
  "train_steps": 100000, # total training steps
  "eval_steps": 0, # run evaluation for this many steps every steps_per_checkpoint
  "model_path": "gs://neo-models/vae_test2/", # directory in which to save the model
  "mesh_shape": "data:16,model:2", # mapping of processors to named dimensions - see mesh-tensorflow repo for more info
  "layout": "batch_dim:data", # which named dimensions of the model to split across the mesh - see mesh-tensorflow repo for more info
  "num_tokens": 512, # vocab size
  "dim": 512, 
  "hidden_dim": 64, # size of hidden dim
  "n_channels": 3, # number of input channels
  "bf_16": false, # if true, the model is trained with bfloat16 precision
  "lr": 0.001, # learning rate [by default learning rate starts at this value, then decays to 10% of this value over the course of the training]
  "num_layers": 3, # number of blocks in the encoder / decoder
  "train_gumbel_hard": true, # whether to use hard or soft gumbel_softmax
  "eval_gumbel_hard": true
}

DALL-E:

{
  "model_type": "dalle",
  "dataset": {
    "train_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords", # glob path to tfrecords data
    "eval_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords",
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, # see above
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000,
  "iterations": 500,
  "train_steps": 100000,
  "predict_steps": 0,
  "eval_steps": 0,
  "n_channels": 3,
  "bf_16": false,
  "lr": 0.001,
  "model_path": "gs://neo-models/dalle_test/",
  "mesh_shape": "data:16,model:2",
  "layout": "batch_dim:data",
  "n_embd": 512, # size of embedding dim
  "text_vocab_size": 50258, # vocabulary size of the text tokenizer
  "image_vocab_size": 512, # vocabulary size of the vae - should equal num_tokens above
  "text_seq_len": 256, # length of text inputs (all inputs longer / shorter will be truncated / padded)
  "n_layers": 6, 
  "n_heads": 4, # number of attention heads. For best performance, n_embd / n_heads should equal 128
  "vae_model": "vae_example" # path to or name of vae model config
}
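A few of these fields have to stay consistent with each other and with the VAE config; as a quick sanity check, here is the arithmetic with the example values above:

    # Values copied from the example configs above.
    n_embd, n_heads = 512, 4
    vae_num_tokens = 512      # "num_tokens" in the VAE config
    image_vocab_size = 512    # "image_vocab_size" in the DALL-E config

    assert image_vocab_size == vae_num_tokens  # DALL-E's image vocab must match the VAE's
    assert n_embd % n_heads == 0
    head_dim = n_embd // n_heads  # 512 / 4 = 128, the recommended per-head size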
Comments
  • Easy to use dataset for DALL-E

    The COCO dataset is a high-quality dataset in terms of both images and text. It consists of around 100,000 images, each with multiple captions. I have created a Google Colab which does the following:

    1. Fetches the COCO images and captions
    2. Allows the user to specify image dimensions which the images will be resized to
    3. Stores information in two files for easy access in the following format:

    od-captions.txt (<image_path> : <image_caption>)

    train2017/000000203564.jpg : A bicycle replica with a clock as the front wheel.
    train2017/000000322141.jpg : A room with blue walls and a white sink and door.
    train2017/00000016977.jpg : A car that seems to be parked illegally behind a legally parked car
    train2017/000000106140.jpg : A large passenger airplane flying through the air.
    train2017/000000571635.jpg : A bathroom with a toilet, sink, and shower.

    od-captionsonly.txt (<image_caption>)

    A bicycle replica with a clock as the front wheel.
    A room with blue walls and a white sink and door.
    A car that seems to be parked illegally behind a legally parked car
    A large passenger airplane flying through the air.
    A bathroom with a toilet, sink, and shower.

    Here is an example image and caption: [image] "The man at bat readies to swing at the pitch while the umpire looks on."

    I have written this hoping it will be somewhat compatible with @htoyryla's fork of a similar project. Might be of interest to you. Feel free to use this to generate a dataset for DALL-E!

    opened by mrconter1 2
  • Anneal gumbel softmax temperature during training

    https://github.com/lucidrains/DALLE-pytorch/issues/10#issuecomment-757132197

    "wow! temperature feature is awesome! Gradually decreasing it from 5 to 0.05 over 5 epochs and convergence is really fast as well as results look much better!!!"

    feature request 
    opened by sdtblck 0
  • Add resblocks to VAE

    You can pretty much copy the following code with appropriate modifications. To be more specific, these are the important ingredients:

    • Each resolution consists of a down/up-sampler and blocks(), which is made up of a certain number of Res blocks.
    • Each Res block looks as in block() (no modification preferred).
    • For the specifics of the up/down-sampler, please refer to the comments.
    • The hidden dimension of each layer depends on the current spatial scale, and the order of blocks() and up/down-sampler() matters.
    def block(self, x, dim, name=None):
        with tf.variable_scope(name):
            hidden_dim = dim // 4
            res = x
            x = self.activation(x, name="activ1")
            x = self.conv2d(x, hidden_dim, (1, 1), (1, 1), padding="SAME", name="conv1",
                            variable_dtype=self.variable_dtype)
            x = self.activation(x, name="activ2")
            x = self.conv2d(x, hidden_dim, (3, 3), (1, 1), padding="SAME", name="conv2",
                            variable_dtype=self.variable_dtype)
            x = self.activation(x, name="activ3")
            x = self.conv2d(x, dim, (1, 1), (1, 1), padding="SAME", name="conv3",
                            variable_dtype=self.variable_dtype)
            x += res  # residual connection
            return x


    def blocks(self, x, dim, num_blocks, name=None):
        with tf.variable_scope(name):
            for idx in range(num_blocks):
                x = self.block(x, dim, name="block" + str(idx))
            return x


    def decoder(self, x, num_res, num_blocks):
        with tf.variable_scope("decoder"):
            dim = tf.shape(x)[1]  # dim = c (assumes channels-first layout)
            # num_res is the number of resolutions, e.g. if 32 -> 64 -> 128 -> 256, then num_res = 4
            for idx in range(num_res):
                x = self.blocks(x, dim // (2 ** idx), num_blocks, name="blocks" + str(idx))
                if idx != num_res - 1:  # not the last layer
                    x = upsample(x)  # nearest-neighbor interpolation upsampling by a factor of 2
            x = self.conv2d(x, self.channels_dim, (1, 1), (1, 1), variable_dtype=self.variable_dtype)
            return x


    def encoder(self, x, num_res, dim, num_blocks):
        with tf.variable_scope("encoder"):
            # dim is the embedding dimension of the first layer of the encoder
            x = self.conv2d(x, dim, (3, 3), (1, 1), padding="SAME", name="conv_in",
                            variable_dtype=self.variable_dtype)
            dim = tf.shape(x)[1]  # dim = c
            for idx in range(num_res):
                x = self.blocks(x, dim * (2 ** idx), num_blocks, name="blocks" + str(idx))
                if idx != num_res - 1:  # not the last layer
                    x = downsample(x)  # average pooling with stride 2
            return x

    opened by AranKomat 0
  • Got Error

    I got an error when running the command: python3 src/data/create_tfrecords.py

    Error:

    Traceback (most recent call last):
      File "src/data/create_tfrecords.py", line 184, in <module>
        tokenizer=None)
      File "src/data/create_tfrecords.py", line 158, in create_paired_dataset
        data = load_jsonl(path)
      File "src/data/create_tfrecords.py", line 32, in load_jsonl
        with open(input_path, 'r', encoding='utf-8') as f:
    FileNotFoundError: [Errno 2] No such file or directory: '/home/data/coco/coco_captions.jsonl'

    opened by ntaapp 1
  • "No pertained models... yet"

    Image generation currently takes a good amount of time, because the model has to be trained first.

    When the pretrained models are released, how big will they be, how long will image generation take, and how much computing power will it require?

    And around when do you plan to release pretrained models?

    Best regards

    opened by robvanvolt 0
Game Agent Framework. Helping you create AIs / Bots that learn to play any game you own!

Serpent.AI - Game Agent Framework (Python) Update: Revival (May 2020) Development work has resumed on the framework with the aim of bringing it into 2

Serpent.AI 6.4k Jan 5, 2023
An experiment on the performance of homemade Q-learning AIs in Agar.io depending on their state representation and available actions

Agar.io_Q-Learning_AI An experiment on the performance of homemade Q-learning AIs in Agar.io depending on their state representation and available act

null 1 Jun 9, 2022
ManiSkill-Learn is a framework for training agents on SAPIEN Open-Source Manipulation Skill Challenge (ManiSkill Challenge), a large-scale learning-from-demonstrations benchmark for object manipulation.

ManiSkill-Learn ManiSkill-Learn is a framework for training agents on SAPIEN Open-Source Manipulation Skill Challenge, a large-scale learning-from-dem

Hao Su's Lab, UCSD 48 Dec 30, 2022
A repo that contains all the mesh keys needed for mesh backend, along with a code example of how to use them in python

Mesh-Keys A repo that contains all the mesh keys needed for mesh backend, along with a code example of how to use them in python Have been seeing alot

Joseph 53 Dec 13, 2022
Mesh Graphormer is a new transformer-based method for human pose and mesh reconstruction from an input image

MeshGraphormer ✨ ✨ This is our research code of Mesh Graphormer. Mesh Graphormer is a new transformer-based method for human pose and mesh reconstructi

Microsoft 251 Jan 8, 2023
CoSMA: Convolutional Semi-Regular Mesh Autoencoder. From Paper "Mesh Convolutional Autoencoder for Semi-Regular Meshes of Different Sizes"

Mesh Convolutional Autoencoder for Semi-Regular Meshes of Different Sizes Implementation of CoSMA: Convolutional Semi-Regular Mesh Autoencoder arXiv p

Fraunhofer SCAI 10 Oct 11, 2022
Given a 2D triangle mesh, we could randomly generate cloud points that fill in the triangle mesh

generate_cloud_points Given a 2D triangle mesh, we could randomly generate cloud points that fill in the triangle mesh. Run python disp_mesh.py Or you

Peng Yu 2 Dec 24, 2021
AI Face Mesh: This is a simple face mesh detection program based on artificial intelligence.

AI Face Mesh: This is a simple face mesh detection program based on artificial intelligence, made with Python. It's able to detect 468 different

Md. Rakibul Islam 1 Jan 13, 2022
Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

DALL-E in Pytorch Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the ge

Phil Wang 5k Jan 4, 2023
PyTorch package for the discrete VAE used for DALL·E.

Overview [Blog] [Paper] [Model Card] [Usage] This is the official PyTorch package for the discrete VAE used for DALL·E. Installation Before running th

OpenAI 9.5k Jan 5, 2023
RuDOLPH: One Hyper-Modal Transformer can be creative as DALL-E and smart as CLIP

[Paper] [Habr] [Model Card] [Colab] [Kaggle] RuDOLPH ☃️ One Hyper-Modal Tr

Sber AI 230 Dec 31, 2022
An Efficient Training Approach for Very Large Scale Face Recognition or F²C for simplicity.

Fast Face Classification (F²C) This is the code of our paper An Efficient Training Approach for Very Large Scale Face Recognition or F²C for simplicit

null 33 Jun 27, 2021
A large-scale video dataset for the training and evaluation of 3D human pose estimation models

ASPset-510 ASPset-510 (Australian Sports Pose Dataset) is a large-scale video dataset for the training and evaluation of 3D human pose estimation mode

Aiden Nibali 36 Oct 30, 2022
A large-scale video dataset for the training and evaluation of 3D human pose estimation models

ASPset-510 (Australian Sports Pose Dataset) is a large-scale video dataset for the training and evaluation of 3D human pose estimation models. It contains 17 different amateur subjects performing 30 sports-related actions each, for a total of 510 action clips.

Aiden Nibali 25 Jun 20, 2021
Official repository for the paper, MidiBERT-Piano: Large-scale Pre-training for Symbolic Music Understanding.

MidiBERT-Piano Authors: Yi-Hui (Sophia) Chou, I-Chun (Bronwin) Chen Introduction This is the official repository for the paper, MidiBERT-Piano: Large-

null 137 Dec 15, 2022
Galileo library for large scale graph training by JD

In recent years, graph computation has achieved notable results in scenarios such as search, recommendation, and risk control, but it also faces challenges such as training ultra-large-scale heterogeneous graphs and integrating with existing deep learning frameworks like TensorFlow and PyTorch. Galileo is a graph deep learning framework with the advantages of very large scale, ease of use, extensibility, high performance, and dual backends, aiming to solve the deployment challenges of ultra-large-scale graph algorithms in industrial-grade scenarios

JD Galileo Team 128 Nov 29, 2022
UniLM AI - Large-scale Self-supervised Pre-training across Tasks, Languages, and Modalities

Pre-trained (foundation) models across tasks (understanding, generation and translation), languages (100+ languages), and modalities (language, image, audio, vision + language, audio + language, etc.)

Microsoft 7.6k Jan 1, 2023
Colossal-AI: A Unified Deep Learning System for Large-Scale Parallel Training

ColossalAI An integrated large-scale model training system with efficient parallelization techniques Installation PyPI pip install colossalai Install

HPC-AI Tech 7.1k Jan 3, 2023
DeepGNN is a framework for training machine learning models on large scale graph data.

DeepGNN Overview DeepGNN is a framework for training machine learning models on large scale graph data. DeepGNN contains all the necessary features in

Microsoft 45 Jan 1, 2023