Official codebase used to develop Vision Transformer, MLP-Mixer, LiT and more.

Overview

Big Vision

This codebase is designed for training large-scale vision models on Cloud TPU VMs. It is based on Jax/Flax libraries, and uses tf.data and TensorFlow Datasets for scalable input pipelines in the Cloud.

The open-sourcing of this codebase has two main purposes:

  1. Publishing the code of research projects developed in this codebase (see a list below).
  2. Providing a strong starting point for running large-scale vision experiments on Google Cloud TPUs, which should scale seamlessly and out of the box from a single TPU core to a distributed setup with up to 2048 TPU cores.

Note that, despite being TPU-centric, our codebase should in general support CPU, GPU and single-host multi-GPU training, thanks to JAX's well-executed and transparent support for multiple backends.
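
For instance, a quick sanity check (not part of the codebase) to confirm which backend and devices JAX picked up:

# Minimal sanity check: print the backend and device counts JAX detected.
import jax

print(jax.default_backend())     # e.g. "tpu", "gpu" or "cpu"
print(jax.device_count())        # total number of devices across all hosts
print(jax.local_device_count())  # devices attached to this host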

big_vision aims to support research projects at Google. We are unlikely to work on feature requests or accept external contributions, unless they were pre-approved (ask in an issue first). For a well-supported transfer-only codebase, see also vision_transformer.

The following research projects were originally conducted in the big_vision codebase:

Architecture research

Multimodal research

Knowledge distillation

Misc

  • Are we done with ImageNet?, by Lucas Beyer*, Olivier J. Hénaff*, Alexander Kolesnikov*, Xiaohua Zhai*, and Aäron van den Oord*

Codebase high-level organization and principles in a nutshell

The main entry point is a trainer module, which typically handles all the boilerplate: creating a model and an optimizer, loading the data, checkpointing, and training/evaluating the model inside a loop. We provide the canonical trainer train.py in the root folder. Normally, individual projects within big_vision fork and customize this trainer.
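
For illustration, the core of such a loop (model and optimizer setup plus a jitted update step) looks roughly like the self-contained toy example below, with synthetic data and a linear model standing in for the real thing; checkpointing and evaluators are omitted, and the actual logic lives in train.py:

# Toy, self-contained analogue of the trainer structure; not big_vision code.
import jax
import jax.numpy as jnp
import optax

def init_params(rng):
  # A tiny linear "model" standing in for a real Flax model.
  return {"w": jax.random.normal(rng, (8, 2)), "b": jnp.zeros((2,))}

def loss_fn(params, batch):
  logits = batch["x"] @ params["w"] + params["b"]
  labels = jax.nn.one_hot(batch["y"], 2)
  return optax.softmax_cross_entropy(logits, labels).mean()

tx = optax.adamw(1e-3)            # in practice the optimizer comes from the config
params = init_params(jax.random.PRNGKey(0))
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, batch):
  loss, grads = jax.value_and_grad(loss_fn)(params, batch)
  updates, opt_state = tx.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state, loss

rng = jax.random.PRNGKey(1)
for step in range(100):           # in practice batches come from a tf.data pipeline
  rng, key = jax.random.split(rng)
  batch = {"x": jax.random.normal(key, (32, 8)),
           "y": jax.random.randint(key, (32,), 0, 2)}
  params, opt_state, loss = train_step(params, opt_state, batch)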

All models, evaluators and preprocessing operations live in the corresponding subdirectories and can often be reused between different projects. We encourage compatible APIs within these directories to facilitate reusability, but it is not strictly enforced, as individual projects may need to introduce their custom APIs.

We have a powerful configuration system, with the configs living in the configs/ directory. Custom trainers and modules can seamlessly extend/modify the configuration options.
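
Configs are plain Python files built on ml_collections.ConfigDict. A minimal sketch of what one might look like (the field names here are illustrative, not an authoritative schema):

# Illustrative config sketch; field names are examples, not the exact schema.
import ml_collections

def get_config():
  config = ml_collections.ConfigDict()
  config.total_epochs = 90
  config.lr = 0.03
  config.input = ml_collections.ConfigDict()
  config.input.data = dict(name='cifar10', split='train')
  config.input.batch_size = 1024
  config.input.pp = 'decode|resize(224)|value_range(-1, 1)|onehot(10)'
  config.model_name = 'vit'
  config.model = ml_collections.ConfigDict(dict(variant='S/16'))
  return config

Individual fields can then be overridden from the command line, e.g. --config.lr=0.03 as in the transfer example further below.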

Training jobs are robust to interruptions and will resume seamlessly from the last saved checkpoint (assuming the user provides the correct --workdir path).

Each configuration file contains a comment at the top with a COMMAND snippet to run it, and some hint of the expected runtime and results. See below for more details, but generally speaking, running on a GPU machine involves calling python -m COMMAND, while running on TPUs, including multi-host, involves

gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
  --command "bash big_vision/run_tpu.sh COMMAND"

See instructions below for more details on how to use Google Cloud TPUs.

Current and future contents

The first release contains the core part of pre-training, transferring, and evaluating classification models at scale on Cloud TPU VMs.

Features and projects we plan to release in the near future, in no particular order:

  • ImageNet-21k in TFDS.
  • MLP-Mixer.
  • Loading misc public models used in our publications (NFNet, MoCov3, DINO).
  • Contrastive Image-Text model training and evaluation as in LiT and CLIP.
  • "Patient and consistent" distillation.
  • Memory-efficient Polyak-averaging implementation.
  • Advanced JAX compute and memory profiling. We are using internal tools for this, but may eventually add support for the publicly available ones.

We will continue releasing code of our future publications developed within big_vision here.

Non-content

The following exist in the internal variant of this codebase, and there is no plan for their release:

  • Regular regression tests for both quality and speed. They rely heavily on internal infrastructure.
  • Advanced logging, monitoring, and plotting of experiments. This also relies heavily on internal infrastructure. However, we are open to ideas on this and may add some in the future, especially if implemented in a self-contained manner.
  • Not yet published, ongoing research projects.

Running on Cloud TPU VMs

Create TPU VMs

To create a single machine with 8 TPU cores, follow this Cloud TPU JAX document: https://cloud.google.com/tpu/docs/run-calculation-jax

To support large-scale vision research, more cores with multiple hosts are recommended. Below we provide instructions on how to set this up.

First, create some useful variables, which will be reused:

export NAME="a name of the TPU deployment, e.g. my-tpu-machine"
export ZONE="GCP geographical zone, e.g. europe-west4-a"
export GS_BUCKET_NAME="Name of the storage bucket, e.g. my_bucket"

The following command line will create a TPU VM with 32 cores across 4 hosts.

gcloud alpha compute tpus tpu-vm create $NAME --zone $ZONE --accelerator-type v3-32 --version tpu-vm-tf-2.8.0

Install big_vision on TPU VMs

Fetch the big_vision repository, copy it to all TPU VM hosts, and install dependencies.

git clone --branch=master https://github.com/google-research/big_vision
gcloud alpha compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --worker=all --zone=$ZONE
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh"

Download and prepare TFDS datasets

Everything in this section needs to be done only once; alternatively, you can do it on your local machine and copy the result to the cloud bucket. For convenience, we provide instructions on how to prepare the data using Cloud TPUs.

Download and prepare the TFDS datasets using a single worker. The seven TFDS datasets used during evaluations will be generated under ~/tensorflow_datasets/ (this should take 10-15 minutes in total).

gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced"

Copy the datasets to the GS bucket, to make them accessible to all TPU workers.

gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "rm -r ~/tensorflow_datasets/downloads && gsutil cp -r ~/tensorflow_datasets gs://$GS_BUCKET_NAME"

If you want to integrate other public or custom datasets, e.g. imagenet2012, please follow the official guideline.
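
For a rough idea of what that preparation involves (the official TFDS guideline remains the authoritative reference; dataset names and paths below are placeholders):

# Rough illustration of TFDS dataset preparation; paths are placeholders.
import tensorflow_datasets as tfds

# Most public datasets can be downloaded and converted in one go.
builder = tfds.builder('oxford_iiit_pet', data_dir='~/tensorflow_datasets')
builder.download_and_prepare()

# imagenet2012 requires the original tar files to be downloaded manually and
# placed in the manual directory before calling download_and_prepare.
builder = tfds.builder('imagenet2012', data_dir='gs://my_bucket/tensorflow_datasets')
builder.download_and_prepare(
    download_config=tfds.download.DownloadConfig(
        manual_dir='~/tensorflow_datasets/downloads/manual'))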

Pre-trained models

For the full list of pre-trained models, check out the load function defined in the same module as the model code. For an example config showing how to use these models, see configs/transfer.py.
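
For illustration only, a downloaded .npz checkpoint can be inspected, and (assuming the flat '/'-separated parameter naming used by the released ViT checkpoints) rebuilt into a nested parameter tree, roughly like this; the canonical loading path is the model module's load function together with configs/transfer.py:

# Illustration only: the filename is a placeholder for a downloaded checkpoint.
import jax
import numpy as np
from flax import traverse_util

ckpt = dict(np.load('vit_augreg_b32_i21k.npz'))
# Rebuild a nested parameter tree from flat '/'-separated keys (assumed layout).
params = traverse_util.unflatten_dict(ckpt, sep='/')
# Print the shape of every parameter.
print(jax.tree_util.tree_map(lambda a: a.shape, params))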

Run the transfer script on TPU VMs

The following command line fine-tunes a pre-trained vit-i21k-augreg-b/32 model on the cifar10 dataset.

gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"

Run the train script on TPU VMs

To train your own big_vision models on a large dataset, e.g. imagenet2012 (prepare the TFDS dataset), run the following command line.

gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py  --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"

ViT baseline

We provide a well-tuned ViT-S/16 baseline in the config file named vit_s16_i1k.py. It achieves 76.5% accuracy on the ImageNet validation split in 90 epochs of training, making it a strong and simple starting point for research on ViT models.

Please see our arXiv note for more details, and if this baseline happens to be useful for your research, consider citing:

@article{vit_baseline,
  url = {https://arxiv.org/abs/2205.01580},
  author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
  title = {Better plain ViT baselines for ImageNet-1k},
  journal={arXiv preprint arXiv:2205.01580},
  year = {2022},
}

Citing the codebase

If you found this codebase useful for your research, please consider using the following BibTeX entry to cite it:

@misc{big_vision,
  author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
  title = {Big Vision},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/google-research/big_vision}}
}

Disclaimer

This is not an official Google product.

Comments
  • implement gsam in jax

    Hi, @lucasb-eyer thanks for your review and comments. I reformatted the files and squashed commits into a new PR (sorry, I messed up the old PR and could not squash commits there). This PR includes:

    1. Put GSAM related configs into config.gsam and call gsam with l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images, targets=labels, lr=learning_rate, **config["gsam"])
    2. Add big_vision/configs/proj/gsam/vit_1k_gsam_no_aug.py, the network used in GSAM paper used pool_type='gap' and rep_size=False, which is different from the default config.
    3. Fix format issues and squash commits.

    Regarding reproducing the experiments, I wonder if it's possible for you to run the script (with 8x8 TPU cores to exactly match the paper)? I'm sorry, I don't have access to TPU resources since I'm no longer affiliated with Google, so I can't run experiments, though the checkpoints and the old version of the code that I used were kept on a server. Thanks so much for your code review and help!

    opened by juntang-zhuang 22
  • Any extra dataset prep needed?

    I have followed the instructions from the README. I have set up a TPU v3-8 machine.

    I have hosted the ImageNet-1k (imagenet2012) dataset in a separate bucket, structured following the linked instructions.

    While launching training, I am using the following command:

    gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://imagenet-1k/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/vit_s16_i1k.py  --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
    

    It results into the following:

    SSH key found in project metadata; not updating instance.
    SSH: Attempting to connect to worker 0...
    2022-05-10 10:30:25.858388: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
    2022-05-10 10:30:27.319919: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
    2022-05-10 10:30:27.319952: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
    I0510 10:30:27.335715 140289404775488 xla_bridge.py:263] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
    I0510 10:30:27.336199 140289404775488 xla_bridge.py:263] Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter TPU Host
    I0510 10:30:30.058175 140289404775488 train.py:65] Hello from process 0 holding 8/8 devices and writing to workdir gs://big_vision_exp/big_vision/workdir/05-10_1030.
    I0510 10:30:30.568850 140289404775488 train.py:95] NOTE: Global batch size 1024 on 1 hosts results in 1024 local batch size. With 8 dev per host (8 dev total), that's a 128 per-device batch size.
    I0510 10:30:30.570343 140289404775488 train.py:95] NOTE: Initializing train dataset...
    I0510 10:30:31.039579 140289404775488 dataset_info.py:522] Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: imagenet2012/5.1.0
    I0510 10:30:31.303886 140289404775488 dataset_info.py:439] Load dataset info from /tmp/tmpggpl8znitfds
    I0510 10:30:31.308489 140289404775488 dataset_info.py:492] Field info.description from disk and from code do not match. Keeping the one from code.
    I0510 10:30:31.308714 140289404775488 dataset_info.py:492] Field info.release_notes from disk and from code do not match. Keeping the one from code.
    I0510 10:30:31.308900 140289404775488 dataset_info.py:492] Field info.supervised_keys from disk and from code do not match. Keeping the one from code.
    I0510 10:30:31.308959 140289404775488 dataset_info.py:492] Field info.module_name from disk and from code do not match. Keeping the one from code.
    I0510 10:30:31.309248 140289404775488 logging_logger.py:44] Constructing tf.data.Dataset imagenet2012 for split _EvenSplit(split='train[:99%]', index=0, count=1, drop_remainder=False), from gs://imagenet-1k/tensorflow_datasets/imagenet2012/5.1.0
    Traceback (most recent call last):
      File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/spsayakpaul/big_vision/train.py", line 372, in <module>
        app.run(main)
      File "/home/spsayakpaul/bv_venv/lib/python3.8/site-packages/absl/app.py", line 312, in run
        _run_main(main, args)
      File "/home/spsayakpaul/bv_venv/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
        sys.exit(main(argv))
      File "/home/spsayakpaul/big_vision/train.py", line 122, in main
        train_ds = input_pipeline.make_for_train(
      File "/home/spsayakpaul/big_vision/input_pipeline.py", line 69, in make_for_train
        data, _ = get_dataset_tfds(dataset=dataset, split=split,
      File "/home/spsayakpaul/big_vision/input_pipeline.py", line 53, in get_dataset_tfds
        return builder.as_dataset(
      File "/home/spsayakpaul/bv_venv/lib/python3.8/site-packages/tensorflow_datasets/core/logging/__init__.py", line 81, in decorator
        return function(*args, **kwargs)
      File "/home/spsayakpaul/bv_venv/lib/python3.8/site-packages/tensorflow_datasets/core/dataset_builder.py", line 565, in as_dataset
        raise AssertionError(
    AssertionError: Dataset imagenet2012: could not find data in gs://imagenet-1k/tensorflow_datasets. Please make sure to call dataset_builder.download_and_prepare(), or pass download=True to tfds.load() before trying to access the tf.data.Dataset object.
    

    Is there anything I'm missing out here?

    opened by sayakpaul 10
  • implement_gsam_jax

    opened by juntang-zhuang 5
  • Adds "LiT: Zero-Shot Transfer with Locked-image text Tuning".

    opened by andsteing 2
  • augmentation and regularization used in MLP-Mixer

    Hello! Thank you for your work! In the MLP-Mixer paper, when training Mixer-B/16 on imagenet1k from scratch, it is said that extra regularization is applied to reach the accuracy of 76%. I want to know what detailed augmentation and regularization strategy was used for that experiment. Is there a config file available? Thank you for your help! : )

    opened by Ga-Lee 2
  • TPU utilization could be improved further?

    Training details are in https://github.com/google-research/big_vision/issues/2

    I think the TPU utilization is a bit lower than expected.

    Is this expected?

    I understand there might be other network access factors that can contribute to this but wanted to know.

    opened by sayakpaul 2
  • Accuracy of vit-b-16 training

    Hi, may I ask what top-1 accuracy ViT-B/16 reaches when trained on imagenet-1k with the config file "vit_i1k.py"? I find the related paper reports an accuracy of about 74.6.

    Thank you very much!

    Best Lucas

    opened by lucasliunju 1
  • Reduce peak memory usage when freezing parameters.

    I discovered optax.set_to_zero() from this thread.

    When compared with the original optax.scale(0.0) on a ViT H/16 with some heads, the peak GPU memory usage (measured with XLA_PYTHON_CLIENT_PREALLOCATE=false) was:

    • Full trainable: 18GiB
    • optax.scale(0.0) (current): 9.8GiB
    • optax.set_to_zero (PR): 5.6GiB

    The frozen weights were set in the config like this (for both the current and the PR change):

      config.schedule = [
        (".*ViT_0/.*", None),
        (".*", dict(warmup_steps=2500))
      ]
    

    Theoretically, memory usage should be the same after jitting, so I'm not sure whether this is a GPU-specific bug in JAX or not.

    opened by lkhphuc 0
  • Could you provide the checkpoint of the CLIPPO model?

    I noticed that you have provided the CLIPPO training code. I hope to explore some downstream tasks based on the pre-trained CLIPPO model. Could you please release the checkpoint?

    Thank you!

    opened by zzhanghub 0
  • Colorization uvim model not working

    Hi @andresusanopinto, I hope you are well. I tried using the colorization model on my images, but out of 4 it colorized only 1 image, and the result is also not good at all. Can you please tell me what I am missing?

    opened by muhammad-ahmed-ghani 5
  • AttributeError pp_img in lit notebook

    Hi. An AttributeError is raised when running the big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb notebook in Colab.

    P.S.: it is raised here: config.pp_img. P.P.S.: config.pp_txt will also raise an AttributeError.

    opened by amrzv 2
  • Add pylint action with Google style configuration.

    This would run pylint with the official Google style configuration on every PR automatically, saving us quite a bit of time.

    We can already see its results in this PR.

    However, it seems that maybe we should submit the configuration here and fine-tune it a bit? I'm seeing several lint errors that we don't actually hit internally. So we should see if we like this difference, and/or where it comes from, before merging.

    opened by lucasb-eyer 0