Unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

Overview

This repository is built upon BEiT, thanks very much!

Currently, we only implement the pretraining process described in the paper, and we cannot guarantee that the performance reported in the paper can be reproduced.

Difference

The shuffle and unshuffle operations used in the paper are not directly available in PyTorch, so we implement the process in another way:

  • For shuffle, we reuse BEiT's method of randomly generating a 14x14 mask map, where mask=0 means the token is kept and mask=1 means the token is dropped (it does not take part in the encoder computation). All visible tokens (mask=0) are then fed into the encoder.
  • For unshuffle, we take the position embeddings of all masked tokens according to the mask map, add the shared mask token to them, concatenate them with the visible tokens coming from the encoder, and feed the result into the decoder for reconstruction (a sketch of both steps follows this list).
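
A minimal PyTorch sketch of this mask-map based shuffle/unshuffle is shown below, assuming a 75% mask ratio; the tensor names, shapes, and argsort-based sampling are illustrative assumptions rather than the repository's exact code.

import torch

B, N, D = 2, 196, 768               # batch size, tokens (14x14 patches), embedding dim
tokens = torch.randn(B, N, D)       # patch embeddings
pos_embed = torch.randn(1, N, D)    # position embeddings (sinusoidal in this repo)
mask_token = torch.zeros(1, 1, D)   # shared mask token

# "Shuffle": randomly generate a mask map per sample, False = keep, True = drop
num_mask = int(0.75 * N)
ids = torch.rand(B, N).argsort(dim=1)                          # random permutation per sample
mask = torch.zeros(B, N, dtype=torch.bool)
mask[torch.arange(B).unsqueeze(1), ids[:, :num_mask]] = True   # mark 75% of the tokens as dropped

# Only the visible tokens (mask == False) are fed to the encoder
x_vis = (tokens + pos_embed)[~mask].reshape(B, -1, D)
# encoded = encoder(x_vis)

# "Unshuffle": position embeddings of the masked locations plus the shared mask
# token are concatenated with the encoder output and passed to the decoder
pos_vis = pos_embed.expand(B, -1, -1)[~mask].reshape(B, -1, D)
pos_mask = pos_embed.expand(B, -1, -1)[mask].reshape(B, -1, D)
# decoder_input = torch.cat([encoded + pos_vis, mask_token + pos_mask], dim=1)
# reconstruction = decoder(decoder_input)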

TODO

  • implement the fine-tuning process
  • reuse the model in modeling_pretrain.py
  • calculate the normalized pixel target (see the sketch after this list)
  • add the cls token to the encoder
  • ...
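
For the normalized pixel target item above, the MAE paper describes normalizing each patch's target pixels by that patch's own mean and standard deviation. A minimal sketch of that computation follows; the shapes and the epsilon are illustrative assumptions, not this repository's code.

import torch

# Per-patch normalized pixel target, as described in the MAE paper:
# each target patch is normalized by its own mean and variance.
patches = torch.randn(2, 196, 16 * 16 * 3)        # (batch, num_patches, pixels per patch)
mean = patches.mean(dim=-1, keepdim=True)
var = patches.var(dim=-1, keepdim=True)
target = (patches - mean) / (var + 1e-6).sqrt()   # reconstruction target for masked patches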

Setup

pip install -r requirements.txt

Run

# Set the path to save checkpoints
OUTPUT_DIR='output/'
# path to imagenet-1k train set
DATA_PATH='../ImageNet_ILSVRC2012/train'


OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_mae_pretraining.py \
        --data_path ${DATA_PATH} \
        --mask_ratio 0.75 \
        --model pretrain_mae_base_patch16_224 \
        --batch_size 128 \
        --opt_betas 0.9 0.95 \
        --warmup_epochs 40 \
        --epochs 1600 \
        --output_dir ${OUTPUT_DIR}

Note: pretraining results are on the way ~

Comments
  • bugs

    MAE-pytorch/modeling_pretrain.py", line 296, in pretrain_mae_base_patch16_224
        **kwargs)
    TypeError: __init__() got an unexpected keyword argument 'num_classes'
    
    opened by cuge1995 9
  • TypeError: __init__() got an unexpected keyword argument 'in_chans'

    Sharing an issue: I initially used the requirements.txt updated on 11.22, which raises the following error:

    File "/MAE-pytorch/modeling_pretrain.py", line 319, in pretrain_mae_base_patch16_224
        **kwargs)
    TypeError: __init__() got an unexpected keyword argument 'in_chans'

    It seems to be caused by timm==0.3.2; upgrading to 0.4.12 fixes the problem. I tested on a V100 (CentOS) and an A100 (Ubuntu), and the issue occurs on both. Uninstalling timm also seems to let it run... This is my first time reading CV code, so I don't know much yet. By the way, did you use a batch size of 64 on the V100?

    opened by Celestial-Bai 7
  • Visual loading model error

    Hello, when I was carrying out the visualization, loading the model failed. Could you please tell me the reason? [image] I fine-tuned the model from the pre-trained weights, and the visualization calls for that fine-tuned model. [image]

    opened by mouxinyue1 6
  • where is the code of freezing the blocks that you don't want to finetune?

    Hello, thanks for your implementation.

    I have read the main part of your code, but I did not find the code that controls partial fine-tuning. Could you please tell me where that part is, in "run_class_finetuning.py", "modeling_finetune.py", or anywhere else?

    Waiting for your reply, thank you.

    opened by A-zhudong 6
  • Positional embedding not stored in checkpoints - problem for tuning/inference at higher resolution

    Hi,

    I'm very impressed with the quick reproduction, nice work!

    I have tried running inference with the provided models and noticed that the current checkpoints do not contain the positional embedding. This is not an issue when running on the same resolution (224,224). However, it makes it tricky to run inference at higher resolutions, since there is no positional encoding to interpolate.

    I have been using the following hard-coded workaround but as far as I can see, the only solution is to change the model so that the positional encoding is stored as part of the checkpoint. Here is the hard-coded solution for loading current models and running tuning/inference:

    # this replaces the code from line 334 in run_class_finetuning.py
    
            # Maybe interpolate position embedding
            old_n_positions = int((224/16)**2)
            if model.pos_embed.shape[1] != old_n_positions:
                embedding_size = model.pos_embed.shape[-1]
                old_pos_embed = modeling_finetune.get_sinusoid_encoding_table(old_n_positions, embedding_size)
                num_patches = model.patch_embed.num_patches
                num_extra_tokens = model.pos_embed.shape[-2] - num_patches
                assert num_extra_tokens == 0, "No support for class tokens"
                # height (== width) for the checkpoint position embedding
                orig_size = int((old_pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
                # height (== width) for the new position embedding
                new_size = int(num_patches ** 0.5)
                # class_token and dist_token are kept unchanged
                if orig_size != new_size:
                    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
                    extra_tokens = old_pos_embed[:, :num_extra_tokens]
                    # only the position tokens are interpolated
                    pos_tokens = old_pos_embed[:, num_extra_tokens:]
                    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
                    pos_tokens = torch.nn.functional.interpolate(
                        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
                    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                    model.pos_embed = new_pos_embed
    
    opened by atonderski 6
  • How to process the dataset...

    I want to know how to process the dataset... there are a lot of tar files.

    Are there any commands or a processing Python file? It is my first time doing this...

    Sorry for my stupid question..

    opened by wanng-ide 4
  • Grad Norm Becomes Inf

    [image] On two GPUs.

    Epoch: [24] [1230/1251] eta: 0:00:06 lr: 0.000375 min_lr: 0.000375 loss: 0.6870 (0.6848) loss_scale: 2097152.0000 (2046895.3111) weight_decay: 0.0500 (0.0500) grad_norm: 0.0929 (0.0969) time: 0.3023 data: 0.0010 max mem: 8361
    Epoch: [24] [1240/1251] eta: 0:00:03 lr: 0.000375 min_lr: 0.000375 loss: 0.6877 (0.6848) loss_scale: 2097152.0000 (2047300.2804) weight_decay: 0.0500 (0.0500) grad_norm: 0.0942 (0.0971) time: 0.2731 data: 0.0018 max mem: 8361
    Epoch: [24] [1250/1251] eta: 0:00:00 lr: 0.000375 min_lr: 0.000375 loss: 0.6856 (0.6849) loss_scale: 2097152.0000 (2047698.7754) weight_decay: 0.0500 (0.0500) grad_norm: 0.0942 (0.0971) time: 0.2560 data: 0.0012 max mem: 8361
    Epoch: [24] Total time: 0:06:23 (0.3067 s / it)
    Averaged stats: lr: 0.000375 min_lr: 0.000375 loss: 0.6856 (0.6851) loss_scale: 2097152.0000 (2047698.7754) weight_decay: 0.0500 (0.0500) grad_norm: 0.0942 (0.0971)
    Epoch: [25] [ 0/1251] eta: 1:25:25 lr: 0.000375 min_lr: 0.000375 loss: 0.6770 (0.6770) loss_scale: 2097152.0000 (2097152.0000) weight_decay: 0.0500 (0.0500) grad_norm: 0.0918 (0.0918) time: 4.0974 data: 3.7792 max mem: 8361
    Epoch: [25] [ 10/1251] eta: 0:13:50 lr: 0.000375 min_lr: 0.000375 loss: 0.6854 (0.6838) loss_scale: 2097152.0000 (2097152.0000) weight_decay: 0.0500 (0.0500) grad_norm: 0.0910 (0.0949) time: 0.6694 data: 0.3704 max mem: 8361

    How does the phenomenon occur?

    opened by TiankaiHang 4
  • help

    Hello, I'm a novice and I want to ask: if I download the ImageNet_ILSVRC2012 dataset directly, can I obtain the pre-trained model by following your pretraining process? If not, how can I get it? And can two GPUs handle the pretraining?

    opened by huanghaiyun-ui 4
  • hello, a little question about this code.

    I am a beginner and very excited about MAE. Regarding the pretrained model pretrain_mae_vit_base_mask_0.75_400e.pth downloaded from Baidu Yun, I would really like to know where to load it in the code. Can it be used in modeling_pretrain.py?

    opened by thebestYezhang 4
  • Managing 512*512 input size

    Thanks for your code! It's really great work! I want to train MAE on a dataset with a 512x512 image size; how should I adjust the model structure? Thanks for your help.

    opened by HardworkingLittlequ 4
  • Run with pretrain_mae_vit_base_mask_0.75_400e.pth get an model initialization error

    Traceback (most recent call last):
      File "run_mae_vis.py", line 138, in <module>
        main(opts)
      File "run_mae_vis.py", line 79, in main
        model = get_model(args)
      File "run_mae_vis.py", line 63, in get_model
        model = create_model(
      File "/home/ub/miniconda3/envs/torch1.8/lib/python3.6/site-packages/timm/models/factory.py", line 57, in create_model
        model = create_fn(**model_args, **kwargs)
      File "/home/ub/bwj/MAE-pytorch/modeling_pretrain.py", line 317, in pretrain_mae_base_patch16_224
        **kwargs)
    TypeError: __init__() got an unexpected keyword argument 'num_classes'

    opened by Owen-Fish 4
  • RuntimeError: Given normalized_shape=[768], expected input with shape [*, 768], but got input of size[12]

    I'm running MAE's fine-tuning and get: "return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) RuntimeError: Given normalized_shape=[768], expected input with shape [*, 768], but got input of size [12]". Please tell me what the problem is.

    opened by zhengzaidenglu 0
  • Which dataset is used for the released pretrained model?

    Hey,

    Thanks very much for this excellent repo, it is definitely worth hundreds of thousands of stars! :)

    I wonder which datasets were used for the released pretrained model? It works almost too well to be true when recovering images from randomly masked FFHQ inputs, especially the backgrounds, for which the visible patches give no clues at all. Could it be that the model has actually seen this data somewhere?

    Thank you very much in advance!!

    opened by AArchLichKing 0
  • Bad transfer learning result while fine tuning in iNaturalist 2019 which is not IN1K

    Dear Author,

    First of all, thank you for your great contribution; it is much appreciated.

    When fine-tuning on IN1K based on the pre-trained model (which was also trained on IN1K), the result is similar to the paper's, as follows: [Screenshot from 2021-12-30 14-13-49]

    But if I fine-tune on iNaturalist with the same pre-trained model and the same fine-tuning parameters listed on your GitHub page, the result is really bad, as follows: [Screenshot from 2021-12-30 14-14-16]

    What do you think the possible reason could be? Looking forward to your reply. Thanks in advance!

    BTW, iNaturalist contains about 260,000 images across 1,010 classes. The train and val data are not separated in iNaturalist, so I split them following the IN1K ratio (96% for train, 4% for val).

    In addition, do you plan to implement fine-tuning code for object detection and semantic segmentation? If yes, how much longer do we need to wait? Thanks again!

    opened by cssddnnc9527 7
Owner
Zhiliang Peng