Scaling Vision with Sparse Mixture of Experts

Overview

This repository contains the code for training and fine-tuning Sparse MoE models for vision (V-MoE) on ImageNet-21k, reproducing the results presented in the paper Scaling Vision with Sparse Mixture of Experts (Riquelme et al., NeurIPS 2021).

We will soon provide a Colab notebook analysing one of the models that we have released, as well as "config" files to train models from scratch and to fine-tune checkpoints. Stay tuned.

Installation

Simply clone this repository.

The file requirements.txt contains the requirements, which can be installed via PyPI. However, we recommend installing jax, flax and optax directly from GitHub, since we use some of the latest features that are not part of any release yet.
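
As a quick sanity check (a sketch, not part of the repository), you can verify that the three core libraries import and report the versions you expect:

    import jax
    import flax
    import optax

    # Print the installed versions of the core libraries.
    print(jax.__version__, flax.__version__, optax.__version__)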

In addition, you also have to clone the Vision Transformer repository, since we use some parts of it.

If you want to use RandAugment to train models (which we recommend if you train on ImageNet-21k or ILSVRC2012 from scratch), you must also clone the Cloud TPU repository, and name it cloud_tpu.

Checkpoints

We release checkpoints containing the weights of some models that we trained on ImageNet (either ILSVRC2012 or ImageNet-21k). Each checkpoint consists of an index file (with the .index extension) and one or more data files (with the extension .data-nnnnn-of-NNNNN, called shards). In the following list, we indicate only the prefix of each checkpoint. We recommend using gsutil to obtain the full list of files, download them, etc.; see the sketch after the list below.

  • V-MoE S/32, 8 experts on the last two odd blocks, trained from scratch on ILSVRC2012 with RandAugment: gs://vmoe_checkpoints/vmoe_s32_last2_ilsvrc2012_randaug_medium.
  • V-MoE B/16, 8 experts on every odd block, trained from scratch on ImageNet-21k with RandAugment: gs://vmoe_checkpoints/vmoe_b16_imagenet21k_randaug_strong.
    • Fine-tuned on ILSVRC2012: gs://vmoe_checkpoints/vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012
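
For example, a minimal sketch that lists every file of one checkpoint (this assumes the tensorflow package, whose tf.io.gfile understands gs:// paths, and read access to the bucket; plain gsutil ls works just as well):

    import tensorflow as tf

    # Prefix of one of the released checkpoints (see the list above).
    prefix = 'gs://vmoe_checkpoints/vmoe_s32_last2_ilsvrc2012_randaug_medium'

    # Prints the .index file and the .data-nnnnn-of-NNNNN shards.
    for path in tf.io.gfile.glob(prefix + '*'):
        print(path)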

Disclaimers

This is not an officially supported Google product.

Comments
  • flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`

    flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`

    This error occurred when I tried to run the cifar10 example on a GPU (TITAN XP).

    Error log

    Traceback (most recent call last):
      File "vmoe/main.py", line 71, in <module>
        app.run(main)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/absl/app.py", line 303, in run
        _run_main(main, args)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
        sys.exit(main(argv))
      File "vmoe/main.py", line 64, in main
        trainer.train_and_evaluate(config=FLAGS.config, workdir=FLAGS.workdir)
      File "/mnt/cephfs/home/cascol/working-directory/MoE-in-MoE/vmoe/vmoe/train/trainer.py", line 434, in train_and_evaluate
        tree=jax.eval_shape(train_state_initialize_fn, train_state_rngs),
      File "/mnt/cephfs/home/cascol/working-directory/MoE-in-MoE/vmoe/vmoe/train/trainer.py", line 203, in initialize
        variables = model.init(rngs, inputs)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/flax/linen/module.py", line 1124, in init
        method=method, mutable=mutable, **kwargs)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/flax/linen/module.py", line 1092, in init_with_output
        {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/flax/linen/module.py", line 1061, in apply
        )(variables, *args, **kwargs, rngs=rngs)
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/flax/core/scope.py", line 706, in wrapper
        with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
      File "/mnt/cephfs/home/cascol/anaconda3/envs/vmoe/lib/python3.7/site-packages/flax/core/scope.py", line 684, in bind
        'rngs should be a dictionary mapping strings to `jax.PRNGKey`.')
    jax._src.traceback_util.UnfilteredStackTrace: flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.InvalidRngError)
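
    For context, flax's Module.init expects rngs to be either a single PRNG key or a dictionary mapping stream names (strings) to keys. A minimal sketch of the accepted format, independent of this repository:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(features=4)(x)

    # `rngs` must map stream names (strings) to jax.random.PRNGKey values.
    rngs = {'params': jax.random.PRNGKey(0)}
    variables = MLP().init(rngs, jnp.ones((1, 8)))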
    

    My environment (possibly relevant):

    cudatoolkit               11.1.1               h6406543_8    conda-forge
    flax                      0.3.6                    pypi_0    pypi
    jax                       0.2.27                   pypi_0    pypi
    jaxlib                    0.1.75+cuda11.cudnn805          pypi_0    pypi
    optax                     0.1.0                    pypi_0    pypi
    pytorch                   1.8.0           py3.7_cuda11.1_cudnn8.0.5_0    pytorch
    

    I tried tracking the error down and found a strange bug (the first screenshot shows the result while setting up training; the second shows what I think the result should look like):

    (screenshots attached to the original issue)

    opened by Cascol-SCUT 8
  • ValueError: Missing field step in state dict while restoring an instance of TrainState

    ValueError: Missing field step in state dict while restoring an instance of TrainState

    Hello! Do I need to make any changes to the scripts before using the checkpoints for evaluation? I downloaded the vmoe_b16_imagenet21k_randaug_strong_ft_cifar10 checkpoint files (both .index and .data-00000-of-00001) and renamed them ckpt_1.index and ckpt_1.data-00000-of-00001, respectively. For running the script on a single partition, I also changed "num_expert_partitions" in the config dict to 1. With the above changes, when I try to run the script on Google Colab using the command below (where saved_checkpoints is the directory containing the checkpoint files):

        python vmoe/main.py --workdir=./vmoe/saved_checkpoints --config=vmoe/configs/vmoe_paper/vmoe_b16_imagenet21k_randaug_strong_ft_cifar10.py

    I get the following error: ValueError: Missing field step in state dict while restoring an instance of TrainState. Any help is appreciated!

    Complete error stack:

    Traceback (most recent call last):
      File "vmoe/main.py", line 77, in <module>
        app.run(main)
      File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 312, in run
        _run_main(main, args)
      File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 258, in _run_main
        sys.exit(main(argv))
      File "vmoe/main.py", line 70, in main
        trainer.train_and_evaluate(config=FLAGS.config, workdir=FLAGS.workdir)
      File "./vmoe/train/trainer.py", line 527, in train_and_evaluate
        return _train_and_evaluate(config, workdir, mesh)
      File "./vmoe/train/trainer.py", line 575, in _train_and_evaluate
        thread_pool=ThreadPool())
      File "./vmoe/train/trainer.py", line 291, in restore_or_create_train_state
        thread_pool=thread_pool)
      File "./vmoe/checkpoints/partitioned.py", line 83, in restore_checkpoint
        'index': tree if tree is not None else axis_resources,
      File "./vmoe/checkpoints/base.py", line 137, in restore_checkpoint
        return serialization.from_bytes(tree, checkpoint_contents)
      File "./vmoe/checkpoints/serialization.py", line 71, in from_bytes
        return from_state_dict(target, state_dict)
      File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 65, in from_state_dict
        return ty_from_state_dict(target, state)
      File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 128, in _restore_dict
        for key, value in xs.items()}
      File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 128, in <dictcomp>
        for key, value in xs.items()}
      File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 65, in from_state_dict
        return ty_from_state_dict(target, state)
      File "/usr/local/lib/python3.7/dist-packages/flax/struct.py", line 146, in from_state_dict
        raise ValueError(f'Missing field {name} in state dict while restoring'
    ValueError: Missing field step in state dict while restoring an instance of TrainState
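
    For reference, this class of error can be reproduced independently of the repository: flax raises the same ValueError whenever the state dict being restored is missing a field of the target flax.struct dataclass. A minimal sketch:

    from flax import serialization, struct

    @struct.dataclass
    class TrainState:
        step: int
        params: dict

    target = TrainState(step=0, params={})
    # The state dict below lacks the 'step' field, so this raises:
    # ValueError: Missing field step in state dict while restoring an instance of TrainState
    serialization.from_state_dict(target, {'params': {}})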
    
    opened by Harsh-Sensei 2
  • [VMOE] Pre-trained models based on ImageNet-1k instead of ImageNet-21k

    [VMOE] Pre-trained models based on ImageNet-1k instead of ImageNet-21k

    The V-MoE model is great. I wonder when the pre-training configs will be made available? The released pre-trained models are all based on ImageNet-21k, which makes sense and is great, but could you also release pre-trained models based on ImageNet-1k? That would be very helpful for benchmarks and comparisons.

    Thanks and looking forward to your responses.

    opened by firestonelib 2
  • [NumPy] Remove references to deprecated NumPy type aliases.

    [NumPy] Remove references to deprecated NumPy type aliases.

    This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

    NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.
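
    For example (a sketch of the substitution, not code from this repository):

    import numpy as np

    # Before (deprecated aliases, removed in NumPy 1.24):
    #   x = np.zeros(3, dtype=np.float)
    #   mask = np.array([True, False], dtype=np.bool)

    # After (recommended built-in replacements):
    x = np.zeros(3, dtype=float)
    mask = np.array([True, False], dtype=bool)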

    opened by copybara-service[bot] 0
  • [JAX] Avoid implicit references to jax._src.

    [JAX] Avoid implicit references to jax._src.

    An upcoming change to JAX means that it will no longer export jax._src by default. It still works (for now) to import modules from jax._src and refer to those, but they will not be present in JAX's namespace except via an explicit import.

    Note that any use of jax._src is a use of a JAX-private API. Please use public APIs instead. This change mostly does not yet switch users onto public APIs.
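
    For example (a sketch of the substitution; the private path in the comment is illustrative):

    import jax

    # Avoid: a private module path that jax will stop re-exporting.
    #   leaves = jax._src.tree_util.tree_leaves(tree)

    # Prefer: the public API.
    leaves = jax.tree_util.tree_leaves({'a': 1, 'b': (2, 3)})
    print(leaves)  # [1, 2, 3]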

    opened by copybara-service[bot] 0
  • Inject current learning rate to the optimizer state. Allow plotting the norm of any array in the train state.

    Inject current learning rate to the optimizer state. Allow plotting the norm of any array in the train state.

    We can now plot the norm of the current parameter values, the learning rate, or any other array that is part of the train state:

    config.plot_norm_train_state_patterns = [
      # norm of the kernel params in MoE layers.
      'params/.*/moe/mlp/.*/kernel',
      # norm (=value) of current learning rate.
      'opt_state/.*/hyperparameter/learning_rate',
      # norm of 1st order moments of the kernel gradients in MoE layers (i.e. Adam inner state).
      'opt_state/.*/mu/.*/moe/mlp/.*/kernel',
    ]
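
    As an illustration only, here is a sketch of how such patterns could select arrays; the flattened path layout below is an assumption for the example, not the repository's actual train-state structure:

    import re

    # Hypothetical '/'-joined paths of arrays in a train state.
    paths = [
        'params/encoder/moe/mlp/expert/kernel',
        'params/encoder/attention/query/kernel',
        'opt_state/inner/hyperparameter/learning_rate',
    ]
    patterns = [
        'params/.*/moe/mlp/.*/kernel',
        'opt_state/.*/hyperparameter/learning_rate',
    ]
    # Keep the paths whose full string matches any configured pattern.
    selected = [p for p in paths
                if any(re.fullmatch(pat, p) for pat in patterns)]
    print(selected)  # the MoE kernels and the learning rate, not the attention kernel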
    
    opened by copybara-service[bot] 0
  • A question about the dissertation

    A question about the dissertation

    Hi, I want to ask a question: even if each expert has a fixed capacity, can the routing algorithm still send all the tokens to only one or two experts during training?
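
    To illustrate the role of capacity, here is a toy sketch (not the repository's router): each expert processes at most `capacity` tokens, and tokens routed beyond that are dropped at the MoE layer, so no expert can absorb the whole batch.

    import numpy as np

    num_experts, capacity = 2, 3
    # Suppose the router sends most tokens to expert 0.
    assignments = np.array([0, 0, 0, 0, 0, 0, 1, 1])

    kept = {e: [] for e in range(num_experts)}
    for token, expert in enumerate(assignments):
        if len(kept[expert]) < capacity:
            kept[expert].append(token)  # processed by this expert
        # otherwise the token is dropped for this layer

    print(kept)  # expert 0 keeps only 3 of its 6 tokens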

    opened by wangning7149 1
  • How do I use the checkpoints from Google Cloud?

    How do I use the checkpoints from Google Cloud?

    This may sound trivial, but I'd like to have some idea of how to use the .data and .index files from gs://vmoe_checkpoints. Does this repo provide any scripts or code to do evaluation/inference with those models? Any help would be appreciated.

    opened by BDHU 3
Owner
Google Research
As-ViT: Auto-scaling Vision Transformers without Training

As-ViT: Auto-scaling Vision Transformers without Training [PDF] Wuyang Chen, Wei Huang, Xianzhi Du, Xiaodan Song, Zhangyang Wang, Denny Zhou In ICLR 2

VITA 68 Sep 5, 2022
Differentiable Neural Computers, Sparse Access Memory and Sparse Differentiable Neural Computers, for Pytorch

Differentiable Neural Computers and family, for Pytorch Includes: Differentiable Neural Computers (DNC) Sparse Access Memory (SAM) Sparse Differentiab

ixaxaar 302 Dec 14, 2022
Official PyTorch implementation of MX-Font (Multiple Heads are Better than One: Few-shot Font Generation with Multiple Localized Experts)

Introduction Pytorch implementation of Multiple Heads are Better than One: Few-shot Font Generation with Multiple Localized Experts. | paper Song Park1

Clova AI Research 97 Dec 23, 2022
EMNLP 2021: Single-dataset Experts for Multi-dataset Question-Answering

MADE (Multi-Adapter Dataset Experts) This repository contains the implementation of MADE (Multi-adapter dataset experts), which is described in the pa

Princeton Natural Language Processing 68 Jul 18, 2022
EMNLP 2021: Single-dataset Experts for Multi-dataset Question-Answering

MADE (Multi-Adapter Dataset Experts) This repository contains the implementation of MADE (Multi-adapter dataset experts), which is described in the pa

Princeton Natural Language Processing 39 Oct 5, 2021
This package implements THOR: Transformer with Stochastic Experts.

THOR: Transformer with Stochastic Experts This PyTorch package implements Taming Sparsely Activated Transformer with Stochastic Experts. Installation

Microsoft 45 Nov 22, 2022
This repository holds the code for the paper "Deep Conditional Gaussian Mixture Model for Constrained Clustering".

Deep Conditional Gaussian Mixture Model for Constrained Clustering. This repository holds the code for the paper Deep Conditional Gaussian Mixture Mod

null 17 Oct 30, 2022
SMD-Nets: Stereo Mixture Density Networks

SMD-Nets: Stereo Mixture Density Networks This repository contains a Pytorch implementation of "SMD-Nets: Stereo Mixture Density Networks" (CVPR 2021)

Fabio Tosi 115 Dec 26, 2022
Audio Source Separation is the process of separating a mixture into isolated sounds from individual sources

Audio Source Separation is the process of separating a mixture into isolated sounds from individual sources (e.g. just the lead vocals).

Victor Basu 14 Nov 7, 2022
[ICLR 2022] Pretraining Text Encoders with Adversarial Mixture of Training Signal Generators

AMOS This repository contains the scripts for fine-tuning AMOS pretrained models on GLUE and SQuAD 2.0 benchmarks. Paper: Pretraining Text Encoders wi

Microsoft 22 Sep 15, 2022
Official code for paper "Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight"

Demystifying Local Vision Transformer, arxiv This is the official PyTorch implementation of our paper. We simply replace local self attention by (dyna

null 138 Dec 28, 2022
Implementation of the 😇 Attention layer from the paper, Scaling Local Self-Attention For Parameter Efficient Visual Backbones

HaloNet - Pytorch Implementation of the Attention layer from the paper, Scaling Local Self-Attention For Parameter Efficient Visual Backbones. This re

Phil Wang 189 Nov 22, 2022
Image-Scaling Attacks and Defenses

Image-Scaling Attacks & Defenses This repository belongs to our publication: Erwin Quiring, David Klein, Daniel Arp, Martin Johns and Konrad Rieck. Ad

Erwin Quiring 163 Nov 21, 2022
A PyTorch implementation of " EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks."

EfficientNet A PyTorch implementation of EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. [arxiv] [Official TF Repo] Implemen

AhnDW 298 Dec 10, 2022
Official pytorch implementation of "Scaling-up Disentanglement for Image Translation", ICCV 2021.

Official pytorch implementation of "Scaling-up Disentanglement for Image Translation", ICCV 2021.

Aviv Gabbay 41 Nov 29, 2022
For auto aligning, cropping, and scaling HR and LR images for training image based neural networks

ImgAlign For auto aligning, cropping, and scaling HR and LR images for training image based neural networks Usage Make sure OpenCV is installed, 'pip

null 15 Dec 4, 2022
Official code for On Path Integration of Grid Cells: Group Representation and Isotropic Scaling (NeurIPS 2021)

On Path Integration of Grid Cells: Group Representation and Isotropic Scaling This repo contains the official implementation for the paper On Path Int

Ruiqi Gao 39 Nov 10, 2022
Implementation of "Scaled-YOLOv4: Scaling Cross Stage Partial Network" using PyTorch framwork.

YOLOv4-large This is the implementation of "Scaled-YOLOv4: Scaling Cross Stage Partial Network" using PyTorch framework. YOLOv4-CSP YOLOv4-tiny YOLOv4-

Kin-Yiu, Wong 2k Jan 2, 2023
Unofficial PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution

PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution [arXiv 2021].

Christoph Reich 122 Dec 12, 2022