A PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

Overview

A PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners.

This is a coarse version of MAE that currently provides only the pretraining model; the finetuning and linear-probing code is coming soon.

1. Introduction

This repo implements the MAE-ViT model in PyTorch without referencing any existing code, so it is an unofficial version. Because of time and machine limitations, I only trained a ViT-Tiny model for the encoder.

2. Environments

  • python 3.7+
  • pytorch 1.7.1
  • pillow
  • timm
  • opencv-python

3. Model Config

Pretrain Config

  • BaseConfig
    img_size = 224,
    patch_size = 16,
  • Encoder: follows the ViT-Tiny model config
    encoder_dim = 192,
    encoder_depth = 12,
    encoder_heads = 3,
  • Decoder: follows the decoder config from Kaiming's paper.
    decoder_dim = 512,
    decoder_depth = 8,
    decoder_heads = 16, 
  • Mask
    1. Shuffle the patches after the sin-cos position embedding for the encoder.
    2. Mask the shuffled patches and keep the mask indices.
    3. Unshuffle the masked patches and combine them with the encoder embedding before the decoder's position embedding.
    4. Reconstruct the image from the decoder embedding with a transposed convolution.
    5. Build the mask map from the mask indices to compute the loss (only the masked patches are considered). A sketch of the shuffle/mask/unshuffle bookkeeping follows this list.
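
    A minimal sketch of steps 1-3 above (an illustration, not the repo's exact code; the function name and the 75% mask ratio are assumptions):

      import torch

      def random_masking(patches, mask_ratio=0.75):
          # patches: (B, N, D) patch embeddings after the sin-cos position embedding
          B, N, D = patches.shape
          num_keep = int(N * (1 - mask_ratio))
          noise = torch.rand(B, N, device=patches.device)      # one random score per patch
          shuffle_index = torch.argsort(noise, dim=1)          # random permutation (shuffle)
          restore_index = torch.argsort(shuffle_index, dim=1)  # inverse permutation (unshuffle)
          keep_index = shuffle_index[:, :num_keep]             # indices of the visible patches
          visible = torch.gather(patches, 1, keep_index.unsqueeze(-1).expand(-1, -1, D))
          # binary mask in the original patch order: 0 = visible, 1 = masked
          mask = torch.ones(B, N, device=patches.device)
          mask[:, :num_keep] = 0
          mask = torch.gather(mask, 1, restore_index)
          return visible, mask, restore_index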

Finetune Config

Waiting for the results.

TODO:

  • Finetune training
  • Linear training

4. Results

The decoder reconstructs ImageNet validation images from the pretrained model. Compared with Kaiming's results, the reconstruction quality is lower; maybe the encoder model is too small TT.

The MAE-ViT-Tiny pretrained model is here; you can download it to test the reconstruction results. Put the checkpoint in the weights folder.

5. Training & Inference

  • dataset prepare: a label file where each line is an image path and its label, as below (a minimal loader sketch follows the example)

    /data/home/imagenet/xxx.jpeg, 0
    /data/home/imagenet/xxx.jpeg, 1
    ...
    /data/home/imagenet/xxx.jpeg, 999
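
    A minimal loader sketch for this label-file format (an assumed illustration; the repo's actual dataset class may differ):

      from PIL import Image
      from torch.utils.data import Dataset

      class ImageListDataset(Dataset):
          """Reads a text file whose lines are '<image_path>, <label>'."""

          def __init__(self, list_file, transform=None):
              with open(list_file) as f:
                  # split on the last comma so paths containing commas still work
                  self.samples = [line.strip().rsplit(",", 1) for line in f if line.strip()]
              self.transform = transform

          def __len__(self):
              return len(self.samples)

          def __getitem__(self, idx):
              path, label = self.samples[idx]
              image = Image.open(path.strip()).convert("RGB")
              if self.transform is not None:
                  image = self.transform(image)
              return image, int(label)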
    
  • Training

    1. Pretrain

      #!/bin/bash
      OMP_NUM_THREADS=1
      MKL_NUM_THREADS=1
      export OMP_NUM_THREADS
      export MKL_NUM_THREADS
      cd MAE-Pytorch;
      CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch --nproc_per_node 8 train_mae.py \
      --batch_size 256 \
      --num_workers 32 \
      --lr 1.5e-4 \
      --optimizer_name "adamw" \
      --cosine 1 \
      --max_epochs 300 \
      --warmup_epochs 40 \
      --num-classes 1000 \
      --crop_size 224 \
      --patch_size 16 \
      --color_prob 0.0 \
      --calculate_val 0 \
      --weight_decay 5e-2 \
      --lars 0 \
      --mixup 0.0 \
      --smoothing 0.0 \
      --train_file $train_file \
      --val_file $val_file \
      --checkpoints-path $ckpt_folder \
      --log-dir $log_folder
    2. Finetune TODO:

      • training
    3. Linear TODO:

      • training
  • Inference

    1. Pretrain
    python mae_test.py --test_image xxx.jpg --ckpt weights.pth
    2. Classification TODO:
      • training
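
For reference, a minimal sketch of the linear-warmup plus cosine-decay learning-rate schedule implied by the pretrain flags (--lr 1.5e-4, --cosine 1, --warmup_epochs 40, --max_epochs 300). This is an assumption; the repo's exact schedule may differ:

    import math

    def lr_at_epoch(epoch, base_lr=1.5e-4, warmup_epochs=40, max_epochs=300, min_lr=0.0):
        # linear warmup to base_lr, then cosine decay towards min_lr
        if epoch < warmup_epochs:
            return base_lr * epoch / warmup_epochs
        progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
        return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))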

6. TODO

  • ViT-Base model training.
  • Swin Transformer for MAE.
  • Finetune & Linear training.

Finetuning is training now; the weights may be coming soon.

Comments
  • Maybe have some bugs in finetune

    It seems that the losses differ between training_mae and finetune. When training MAE and finetuning, the forward function makes no distinction between the two cases, but when calculating the loss, finetune receives two variables while training_mae receives only one.

    with autocast():
        if args.finetune:
            # one variable?
            outputs = model(inputs)
            losses = criterion(outputs, targets)
        else:
            # two variables?
            outputs, mask_index = model(inputs)
            print('shapex', outputs.shape, type(mask_index))
            mask = build_mask(mask_index, args.patch_size, args.crop_size)
            losses = criterion(outputs, inputs, mask)

    Code in the model forward:

        def forward(self, x):
            # batch, c, h, w
            norm_embeeding, sample_index, mask_index = self.Encoder.autoencoder(x)
            proj_embeeding = self.proj(norm_embeeding)
            decode_embeeding = self.Decoder.decoder(proj_embeeding, sample_index, mask_index)
            outputs = self.restruction(decode_embeeding)

            cls_token = outputs[:, 0, :]
            image_token = outputs[:, 1:, :]  # (b, num_patches, patches_vector)
            # cal the mask patches normalization Independent
            image_norm_token = self.patch_norm(image_token)
            n, l, dim = image_norm_token.shape
            image_norm_token = image_norm_token.view(-1, self.num_patch[0], self.num_patch[1], dim).permute(0, 3, 1, 2)
            restore_image = self.unconv(image_norm_token)
            # return same variables in both training_mae and finetune
            return restore_image, mask_index
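
    For reference, one possible build_mask and masked reconstruction loss consistent with the criterion(outputs, inputs, mask) call above (a sketch under assumed shapes, not the repo's code):

        import torch.nn.functional as F

        def build_mask(mask_index, patch_size, crop_size):
            # mask_index: (B, N) binary map in the original patch order, 1 = masked patch
            B, N = mask_index.shape
            grid = crop_size // patch_size                    # patches per image side
            mask = mask_index.float().view(B, 1, grid, grid)
            # expand each patch flag to its patch_size x patch_size pixel block
            return F.interpolate(mask, scale_factor=patch_size, mode="nearest")

        def masked_mse_loss(restored, images, mask):
            # restored/images: (B, C, H, W); mask: (B, 1, H, W), 1 = masked pixels
            loss = F.mse_loss(restored, images, reduction="none")
            # average over the masked pixels only (step 5 of the Mask config)
            return (loss * mask).sum() / mask.sum().clamp(min=1)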
    
    opened by Nioolek 3
  • Error occurs

    Hi~ I'm really impressed with your code. I want to restore the hidden part of a face. Inference with the pretrained model works, but when I try to train with your code, many problems arise. Advice please.

    CUDA_VISIBLE_DEVICES=0,1 python -W ignore -m torch.distributed.launch --nproc_per_node 8 train_mae.py
    rank: 1 / 2  rank: 4 / 2  rank: 0 / 2  rank: 3 / 2  rank: 5 / 2  rank: 2 / 2  rank: 6 / 2  rank: 7 / 2
    Traceback (most recent call last):
      File "train_mae.py", line 692, in <module>
        main_worker(args)
      File "train_mae.py", line 205, in main_worker
        torch.cuda.set_device(args.local_rank)
      File "/home/vimlab/anaconda3/envs/stylegan2_pytorch/lib/python3.6/site-packages/torch/cuda/__init__.py", line 261, in set_device
        torch._C._cuda_setDevice(device)
    RuntimeError: CUDA error: invalid device ordinal
    (the same traceback is raised by each of the eight launched processes)
    Killing subprocess 1520881
    Killing subprocess 1520882
    Killing subprocess 1520883
    Killing subprocess 1520884
    Killing subprocess 1520885
    Killing subprocess 1520886
    Killing subprocess 1520889
    Killing subprocess 1520893
    Traceback (most recent call last):
      File "/home/vimlab/anaconda3/envs/stylegan2_pytorch/lib/python3.6/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/home/vimlab/anaconda3/envs/stylegan2_pytorch/lib/python3.6/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/home/vimlab/anaconda3/envs/stylegan2_pytorch/lib/python3.6/site-packages/torch/distributed/launch.py", line 340, in <module>
        main()
      File "/home/vimlab/anaconda3/envs/stylegan2_pytorch/lib/python3.6/site-packages/torch/distributed/launch.py", line 326, in main
        sigkill_handler(signal.SIGTERM, None)  # not coming back
      File "/home/vimlab/anaconda3/envs/stylegan2_pytorch/lib/python3.6/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
        raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
    subprocess.CalledProcessError: Command '['/home/vimlab/anaconda3/envs/stylegan2_pytorch/bin/python', '-u', 'train_mae.py', '--local_rank=7']' returned non-zero exit status 1.

    opened by leeisack 1
  • How to apply "mae" to a detection pipeline?

    Thanks for sharing the code. I have a question: how can I use MAE as a backbone for an object detection framework? If you have any guidance, I'd appreciate your help.

    opened by zobeirraisi 1
  • load_state_dict, size mismatch

    Hi, I'm not good at PyTorch modeling, but how can I download the pretrained weights for inference.py?

    I downloaded the pretrained weights as you mentioned below, but I can't load them correctly.

    The ViT-Tiny/16 pretrained model is here; the ViT-Base/16 pretrained model is here.

    The error message is:

    load_state_dict, size mismatch for cls_token: copying a param with shape torch.Size([1, 1, 192]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).

    It seems that /model/Transformers/VIT/mae.py around line 200 is incorrect. Thank you for reading.

    opened by ichiyasa0308 1
  • What are the rules for setting the parameters of vit-tiny's decoder?

    Thanks for your work! I'm pretraining ViT-Tiny on my own dataset, but I cannot determine the settings for the decoder's parameters (depth/embed_dim/num_heads). Should they stay consistent with ViT-Base/Large/Huge, or should I choose smaller values to make a lightweight decoder?

    opened by zzzzzzyang 1