PyTorch implementation of a collection of scalable Video Transformer benchmarks.

Overview

PyTorch implementation of Video Transformer Benchmarks

This repository is mainly built upon PyTorch and PyTorch-Lightning. We wish to maintain a collection of scalable video transformer benchmarks and to discuss training recipes for large video transformer models.

So far, we have implemented TimeSformer and ViViT, and we have pre-trained TimeSformer-B on Kinetics-600, though we still cannot guarantee the performance reported in the paper. However, we have found some relevant hyper-parameters that may help us reach the target performance.

Table of Contents

  1. Difference
  2. TODO
  3. Setup
  4. Usage
  5. Result
  6. Acknowledge
  7. Contribution

Difference

In order to share the basic divided space-time attention module across different video transformers, we make the following changes.

1. Position embedding

We split the position embedding of shape R^((n_t·n_h·n_w)×d) mentioned in the ViViT paper into a spatial embedding R^((n_h·n_w)×d) and a temporal embedding R^(n_t×d), to stay consistent with TimeSformer.
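
Below is a minimal sketch, with assumed tensor shapes rather than the repository's exact code, of how separate spatial and temporal position embeddings can be added to the patch tokens:

import torch
import torch.nn as nn

B, T, N, D = 2, 8, 196, 768                      # batch, frames, patches per frame, embed dim
pos_embed = nn.Parameter(torch.zeros(1, N, D))   # spatial position embedding, shared across frames
time_embed = nn.Parameter(torch.zeros(1, T, D))  # temporal position embedding, shared across patches

x = torch.randn(B, T, N, D)                      # patch tokens (class token omitted for brevity)
x = x + pos_embed.unsqueeze(1)                   # (1, 1, N, D) broadcasts over the time dimension
x = x + time_embed.unsqueeze(2)                  # (1, T, 1, D) broadcasts over the spatial dimension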

2. Class token

To make it clear whether the class_token takes part in the forward computation, we only compute the interaction between the class_token and the query in the last attention layer of each transformer block (i.e., the last layer apart from the FFN).
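
A minimal sketch of this idea is shown below; the module names are assumptions and the reshaping between time and space (as well as the layer norms) is omitted, so this is not the repository's exact implementation:

import torch
import torch.nn as nn

class DividedBlock(nn.Module):
    def __init__(self, dim=768, heads=12):
        super().__init__()
        self.temporal_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.spatial_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, cls_token, patch_tokens):
        # temporal attention: patch tokens only, the class token is left untouched
        t_out, _ = self.temporal_attn(patch_tokens, patch_tokens, patch_tokens)
        patch_tokens = patch_tokens + t_out
        # spatial attention, the last attention of the block: the class token joins here
        x = torch.cat([cls_token, patch_tokens], dim=1)
        s_out, _ = self.spatial_attn(x, x, x)
        x = x + s_out
        x = x + self.ffn(x)
        return x[:, :1], x[:, 1:]

# cls, patches = DividedBlock()(torch.zeros(2, 1, 768), torch.zeros(2, 196, 768))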

3. Initialize from the pre-trained model

  • Tokenization: the token embedding filter can be either Conv2D or Conv3D. When initializing Conv3D filters from Conv2D weights, the 2D kernel can either be replicated along the temporal dimension and averaged, or placed at the central temporal position t/2 with zeros elsewhere (see the sketch after this list).
  • Temporal MSA module weights: one can either copy the weights from the spatial MSA module or initialize all weights with zeros.
  • Initialization from the MAE pre-trained model provided by ZhiLiang: the class_token, which does not appear in the MAE pre-trained model, is initialized from a truncated normal distribution.
  • Initialization from the ViT pre-trained model can be found here.
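
Below is a minimal sketch of the two Conv2D-to-Conv3D inflation strategies mentioned in the Tokenization bullet; the function name and shapes are assumptions, not the repository's exact code:

import torch

def inflate_patch_embed(w2d: torch.Tensor, t: int, mode: str = "average") -> torch.Tensor:
    """Turn a (out_ch, in_ch, kh, kw) Conv2D kernel into a (out_ch, in_ch, t, kh, kw) Conv3D kernel."""
    if mode == "average":
        # replicate along time and divide by t, so the response to a static clip is unchanged
        return w2d.unsqueeze(2).repeat(1, 1, t, 1, 1) / t
    if mode == "central_frame":
        # zeros at every temporal position except the centre t // 2
        w3d = torch.zeros(w2d.size(0), w2d.size(1), t, w2d.size(2), w2d.size(3))
        w3d[:, :, t // 2] = w2d
        return w3d
    raise ValueError(f"unknown mode: {mode}")

# w3d = inflate_patch_embed(torch.randn(768, 3, 16, 16), t=2, mode="central_frame")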

TODO

  • Add more TimeSformer and ViViT variants with pre-trained weights.
    • Larger versions and other operation types.
  • Add linear probing and partial fine-tuning.
    • Make it possible to transfer the pre-trained models to downstream tasks.
  • Add more scalable Video Transformer benchmarks.
    • We will also extend to multi-modality versions, e.g. Perceiver is coming soon.
  • Add more diverse objective functions.
    • Pre-train on larger datasets with dominant self-supervised methods, e.g. contrastive learning and MAE.

Setup

pip install -r requirements.txt

Usage

Training

# path to Kinetics600 train set
TRAIN_DATA_PATH='/path/to/Kinetics600/train_list.txt'
# path to root directory
ROOT_DIR='/path/to/work_space'

python model_pretrain.py \
	-lr 0.005 \
	-pretrain 'vit' \
	-epoch 15 \
	-batch_size 8 \
	-num_class 600 \
	-frame_interval 32 \
	-root_dir $ROOT_DIR \
	-train_data_path $TRAIN_DATA_PATH

The minimal folder structure will look like this:

root_dir
├── pretrain_model
│   ├── pretrain_mae_vit_base_mask_0.75_400e.pth
│   ├── vit_base_patch16_224.pth
├── results
│   ├── experiment_tag
│   │   ├── ckpt
│   │   ├── log

Inference

# path to Kinetics600 pre-trained model
PRETRAIN_PATH='/path/to/pre-trained model'
# path to the test video sample
VIDEO_PATH='/path/to/video sample'

python model_inference.py \
	-pretrain $PRETRAIN_PATH \
	-video_path $VIDEO_PATH \
	-num_frames 8 \
	-frame_interval 32

Result

Kinetics-600

1. Model Zoo

| name | pretrain | epochs | num frames | spatial crop | top1_acc | top5_acc | weight | log |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| TimeSformer-B | ImageNet-21K | 15e | 8 | 224 | 78.4 | 93.6 | Google drive or BaiduYun (code: yr4j) | log |

2. Train Recipe(ablation study)

2.1 Acc

| operation | top1_acc | top5_acc | top1_acc (three crop) |
| --- | --- | --- | --- |
| base | 68.2 | 87.6 | - |
| + frame_interval 4 -> 16 (span more time) | 72.9 (+4.7) | 91.0 (+3.4) | - |
| + RandomCrop, flip (overcome overfitting) | 75.7 (+2.8) | 92.5 (+1.5) | - |
| + batch size 16 -> 8 (more iterations) | 75.8 (+0.1) | 92.4 (-0.1) | - |
| + frame_interval 16 -> 24 (span more time) | 77.7 (+1.9) | 93.3 (+0.9) | 78.4 |
| + frame_interval 24 -> 32 (span more time) | 78.4 (+0.7) | 94.0 (+0.7) | 79.1 |

Tip: frame_interval and data augmentation account for a large part of the validation accuracy.


2.2 Time

| operation | epoch_time |
| --- | --- |
| base (start with DDP) | 9h+ |
| + speed-up training recipes | 1h+ |
| + switch from get_batch first to sample_indice first | 0.5h |
| + batch size 16 -> 8 | 33.32m |
| + num_workers 8 -> 4 | 35.52m |
| + frame_interval 16 -> 24 | 44.35m |

Tip: increasing frame_interval adds a lot to the epoch time.

1. Speed-up training recipes (a small sketch follows this list):

  • More GPU devices.
  • pin_memory=True.
  • Avoid unnecessary GPU->CPU transfers (such as .item(), .numpy(), .cpu() calls on tensors, or logging to disk every step).
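
A small self-contained sketch of these recipes, using a hypothetical toy dataset and model rather than the repository's trainer: pinned memory, non-blocking host-to-device copies, and accumulating the loss on the device instead of calling .item() every step (each .item() forces a GPU->CPU synchronization):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# toy stand-ins for the real video dataset and model
dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=8, num_workers=4, pin_memory=True, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

running_loss = torch.zeros(1, device=device)       # accumulate the loss on the device
for clips, labels in loader:
    clips = clips.to(device, non_blocking=True)    # overlaps the copy with compute when pinned
    labels = labels.to(device, non_blocking=True)
    loss = criterion(model(clips), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    running_loss += loss.detach()                  # no .item() inside the loop
print((running_loss / len(loader)).item())         # a single GPU->CPU sync at the very end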

2. get_batch first means that we first read all frames through the video reader and then take the target slice of frames, which largely slows down data loading.
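
As a rough illustration of the sample_indice-first approach, the sketch below computes the target frame indices first and asks decord to decode only those frames; the function is a hypothetical example, not the repository's data loader:

import numpy as np
from decord import VideoReader, cpu

def load_clip(video_path, num_frames=8, frame_interval=32):
    """Decode only the sampled frames instead of reading the whole video first."""
    vr = VideoReader(video_path, ctx=cpu(0))
    span = (num_frames - 1) * frame_interval + 1
    start = np.random.randint(0, max(1, len(vr) - span + 1))
    indices = start + np.arange(num_frames) * frame_interval
    indices = np.clip(indices, 0, len(vr) - 1)
    return vr.get_batch(indices.tolist()).asnumpy()   # (num_frames, H, W, 3) uint8 array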


Acknowledge

This repo is built on top of PyTorch-Lightning, decord, and kornia. I also learned many code designs from MMAction2. I thank the authors for releasing their code.

Contribution

I look forward to hearing ideas about the repo; please feel free to report them in an issue, or even better, submit a pull request.

And your star is my motivation, thank u~

Comments
  • Example training command/performance

    Trying to get top1_acc of >78 as shown in the example log.

    Do we know the settings and dataset used for training?

    I am training on K400 and using the command from the example:

    python model_pretrain.py \
    	-lr 0.005 \
    	-pretrain 'vit' \
    	-objective 'supervised' \
    	-epoch 30 \
    	-batch_size 8 \
    	-num_workers 4 \
    	-arch 'timesformer' \
    	-attention_type 'divided_space_time' \
    	-num_frames 8 \
    	-frame_interval 32 \
    	-num_class 400 \
    	-optim_type 'sgd' \
    	-lr_schedule 'cosine' \
    	-root_dir ROOT_DIR \
    	-train_data_path TRAIN_DATA_PATH \
    	-val_data_path VAL_DATA_PATH

    I am unable to get above 73. Increasing frame_interval does not help.

    Curious what I can do to get similar performance.

    opened by Enclavet 7
  • Request code for finetune with self-supervised pretrained weights

    I tried to do self-supervised experiments with your code, but ran into a lot of problems during the fine-tuning stage. Can you share your MVIT finetune code? Thank you!

    opened by WHlTE-N0lSE 4
  • Missing keys in demo notebook

    Hi, thank you for sharing your work.

    When I follow the instructions in the notebook file (VideoTransformer_demo.ipynb), I have trouble loading the pre-trained weights of the ViViT model.

    After downloading and placing the "./vivit_model.pth" file, I was able to instantiate the ViViT model. However, the log says that there are many missing keys in the given pth file.

    Is this the desired behavior, or should I do some preprocessing to match the parameter names?

    This is the output after parameter loading.

    load model finished, the missing key of transformer is:['transformer_layers.0.layers.0.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.0.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.0.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.0.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.1.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.1.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.1.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.1.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.2.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.2.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.2.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.2.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.3.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.3.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.3.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.3.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.4.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.4.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.4.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.4.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.5.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.5.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.5.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.5.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.6.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.6.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.6.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.6.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.7.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.7.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.7.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.7.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.8.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.8.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.8.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.8.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.9.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.9.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.9.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.9.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.10.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.10.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.10.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.10.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.11.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.11.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.11.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.11.attentions.0.attn.proj.bias', 'transformer_layers.1.layers.0.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.0.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.0.attentions.0.attn.proj.weight', 'transformer_layers.1.layers.0.attentions.0.attn.proj.bias', 'transformer_layers.1.layers.1.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.1.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.1.attentions.0.attn.proj.weight', 'transformer_layers.1.layers.1.attentions.0.attn.proj.bias', 
'transformer_layers.1.layers.2.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.2.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.2.attentions.0.attn.proj.weight', 'transformer_layers.1.layers.2.attentions.0.attn.proj.bias', 'transformer_layers.1.layers.3.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.3.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.3.attentions.0.attn.proj.weight', 'transformer_layers.1.layers.3.attentions.0.attn.proj.bias'], cls is:[]

    Thank you in advance!

    +edit) FYI, these are the unexpected keys from the load_state_dict(). transformer unexpected: ['cls_head.weight', 'cls_head.bias', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.5.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.5.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.7.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.7.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.8.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.8.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.bias', 
'transformer_layers.1.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.bias']

    classification head unexpected: ['cls_token', 'pos_embed', 'time_embed', 'patch_embed.projection.weight', 'patch_embed.projection.bias', 'transformer_layers.0.layers.0.attentions.0.norm.weight', 'transformer_layers.0.layers.0.attentions.0.norm.bias', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.0.ffns.0.norm.weight', 'transformer_layers.0.layers.0.ffns.0.norm.bias', 'transformer_layers.0.layers.0.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.0.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.0.ffns.0.layers.1.weight', 'transformer_layers.0.layers.0.ffns.0.layers.1.bias', 'transformer_layers.0.layers.1.attentions.0.norm.weight', 'transformer_layers.0.layers.1.attentions.0.norm.bias', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.1.ffns.0.norm.weight', 'transformer_layers.0.layers.1.ffns.0.norm.bias', 'transformer_layers.0.layers.1.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.1.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.1.ffns.0.layers.1.weight', 'transformer_layers.0.layers.1.ffns.0.layers.1.bias', 'transformer_layers.0.layers.2.attentions.0.norm.weight', 'transformer_layers.0.layers.2.attentions.0.norm.bias', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.2.ffns.0.norm.weight', 'transformer_layers.0.layers.2.ffns.0.norm.bias', 'transformer_layers.0.layers.2.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.2.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.2.ffns.0.layers.1.weight', 'transformer_layers.0.layers.2.ffns.0.layers.1.bias', 'transformer_layers.0.layers.3.attentions.0.norm.weight', 'transformer_layers.0.layers.3.attentions.0.norm.bias', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.3.ffns.0.norm.weight', 'transformer_layers.0.layers.3.ffns.0.norm.bias', 'transformer_layers.0.layers.3.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.3.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.3.ffns.0.layers.1.weight', 'transformer_layers.0.layers.3.ffns.0.layers.1.bias', 'transformer_layers.0.layers.4.attentions.0.norm.weight', 'transformer_layers.0.layers.4.attentions.0.norm.bias', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.4.ffns.0.norm.weight', 'transformer_layers.0.layers.4.ffns.0.norm.bias', 'transformer_layers.0.layers.4.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.4.ffns.0.layers.0.0.bias', 
'transformer_layers.0.layers.4.ffns.0.layers.1.weight', 'transformer_layers.0.layers.4.ffns.0.layers.1.bias', 'transformer_layers.0.layers.5.attentions.0.norm.weight', 'transformer_layers.0.layers.5.attentions.0.norm.bias', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.5.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.5.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.5.ffns.0.norm.weight', 'transformer_layers.0.layers.5.ffns.0.norm.bias', 'transformer_layers.0.layers.5.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.5.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.5.ffns.0.layers.1.weight', 'transformer_layers.0.layers.5.ffns.0.layers.1.bias', 'transformer_layers.0.layers.6.attentions.0.norm.weight', 'transformer_layers.0.layers.6.attentions.0.norm.bias', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.6.ffns.0.norm.weight', 'transformer_layers.0.layers.6.ffns.0.norm.bias', 'transformer_layers.0.layers.6.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.6.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.6.ffns.0.layers.1.weight', 'transformer_layers.0.layers.6.ffns.0.layers.1.bias', 'transformer_layers.0.layers.7.attentions.0.norm.weight', 'transformer_layers.0.layers.7.attentions.0.norm.bias', 'transformer_layers.0.layers.7.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.7.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.7.ffns.0.norm.weight', 'transformer_layers.0.layers.7.ffns.0.norm.bias', 'transformer_layers.0.layers.7.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.7.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.7.ffns.0.layers.1.weight', 'transformer_layers.0.layers.7.ffns.0.layers.1.bias', 'transformer_layers.0.layers.8.attentions.0.norm.weight', 'transformer_layers.0.layers.8.attentions.0.norm.bias', 'transformer_layers.0.layers.8.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.8.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.8.ffns.0.norm.weight', 'transformer_layers.0.layers.8.ffns.0.norm.bias', 'transformer_layers.0.layers.8.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.8.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.8.ffns.0.layers.1.weight', 'transformer_layers.0.layers.8.ffns.0.layers.1.bias', 'transformer_layers.0.layers.9.attentions.0.norm.weight', 'transformer_layers.0.layers.9.attentions.0.norm.bias', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.9.ffns.0.norm.weight', 'transformer_layers.0.layers.9.ffns.0.norm.bias', 'transformer_layers.0.layers.9.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.9.ffns.0.layers.0.0.bias', 
'transformer_layers.0.layers.9.ffns.0.layers.1.weight', 'transformer_layers.0.layers.9.ffns.0.layers.1.bias', 'transformer_layers.0.layers.10.attentions.0.norm.weight', 'transformer_layers.0.layers.10.attentions.0.norm.bias', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.10.ffns.0.norm.weight', 'transformer_layers.0.layers.10.ffns.0.norm.bias', 'transformer_layers.0.layers.10.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.10.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.10.ffns.0.layers.1.weight', 'transformer_layers.0.layers.10.ffns.0.layers.1.bias', 'transformer_layers.0.layers.11.attentions.0.norm.weight', 'transformer_layers.0.layers.11.attentions.0.norm.bias', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.11.ffns.0.norm.weight', 'transformer_layers.0.layers.11.ffns.0.norm.bias', 'transformer_layers.0.layers.11.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.11.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.11.ffns.0.layers.1.weight', 'transformer_layers.0.layers.11.ffns.0.layers.1.bias', 'transformer_layers.1.layers.0.attentions.0.norm.weight', 'transformer_layers.1.layers.0.attentions.0.norm.bias', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.0.ffns.0.norm.weight', 'transformer_layers.1.layers.0.ffns.0.norm.bias', 'transformer_layers.1.layers.0.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.0.ffns.0.layers.0.0.bias', 'transformer_layers.1.layers.0.ffns.0.layers.1.weight', 'transformer_layers.1.layers.0.ffns.0.layers.1.bias', 'transformer_layers.1.layers.1.attentions.0.norm.weight', 'transformer_layers.1.layers.1.attentions.0.norm.bias', 'transformer_layers.1.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.1.ffns.0.norm.weight', 'transformer_layers.1.layers.1.ffns.0.norm.bias', 'transformer_layers.1.layers.1.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.1.ffns.0.layers.0.0.bias', 'transformer_layers.1.layers.1.ffns.0.layers.1.weight', 'transformer_layers.1.layers.1.ffns.0.layers.1.bias', 'transformer_layers.1.layers.2.attentions.0.norm.weight', 'transformer_layers.1.layers.2.attentions.0.norm.bias', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.2.ffns.0.norm.weight', 'transformer_layers.1.layers.2.ffns.0.norm.bias', 'transformer_layers.1.layers.2.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.2.ffns.0.layers.0.0.bias', 
'transformer_layers.1.layers.2.ffns.0.layers.1.weight', 'transformer_layers.1.layers.2.ffns.0.layers.1.bias', 'transformer_layers.1.layers.3.attentions.0.norm.weight', 'transformer_layers.1.layers.3.attentions.0.norm.bias', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.3.ffns.0.norm.weight', 'transformer_layers.1.layers.3.ffns.0.norm.bias', 'transformer_layers.1.layers.3.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.3.ffns.0.layers.0.0.bias', 'transformer_layers.1.layers.3.ffns.0.layers.1.weight', 'transformer_layers.1.layers.3.ffns.0.layers.1.bias', 'norm.weight', 'norm.bias']

    opened by Simcs 3
  • How do we load ImageNet-21k ViT weights?

    Hi guys, thanks for open sourcing this repo!

    I see that your pretrained K600 models were initialized from the ViT ImageNet-21k weights. Can you share a snippet on how you initialized them? Did you use the models from timm?

    Thanks!

    opened by Darktex 3
  • AssertionError:When loading annotation,an assertion error for label appears

    I used MaskFeat for pre-training according to the author's settings, and the train_data_path is k400_classmap.json. I didn't make any other changes to the code. Then the following assertion error occurred.

    [two screenshots of the assertion error omitted]

    Thank you for your help!

    opened by yanmingchao2 3
  • the format of train_file.txt and loaded self-supervised pretrain checkpoints

    Hi, can you provide train_file.txt, val_file.txt and test_file.txt? My email is [email protected]. By the way, I would like to share with you how to train a model loaded with self-supervised pretrain checkpoints.

    opened by happy-hsy 3
  • error happened when I run dataset.py

    Error information:

    File "D:\anaconda3\envs\adarnn\lib\site-packages\torchvision\transforms\functional.py", line 494, in resized_crop
        assert _is_pil_image(img), 'img should be PIL Image'
    AssertionError: img should be PIL Image

    My configuration: win10, python 3.7, torch 1.6.0. Your reply would be appreciated! Thank you very much!

    opened by nlpofwhat 2
  • Vivit Training Problem

    First of all, thank you for your excellent work! Let me describe my configuration first. I set the model training hyperparameters according to the Training section you gave, with two main changes: I changed the dataset and used your Kinetics pre-trained model. I am using the VGGSound dataset, which also splits each video into a sequence of RGB image frames. The problem occurs in the model training phase: when initializing from the pre-trained model and training for 1 epoch, the accuracy reaches 0.2, but it decreases as training progresses.

    2022-07-04 18:43:18 - Evaluating mean top1_acc:0.213, top5_acc:0.427 of current training epoch
    2022-07-04 18:48:55 - Evaluating mean top1_acc:0.171, top5_acc:0.360 of current validation epoch
    2022-07-04 21:08:07 - Evaluating mean top1_acc:0.197, top5_acc:0.430 of current training epoch
    2022-07-04 21:12:59 - Evaluating mean top1_acc:0.071, top5_acc:0.202 of current validation epoch
    2022-07-04 23:30:01 - Evaluating mean top1_acc:0.059, top5_acc:0.175 of current training epoch
    2022-07-04 23:34:57 - Evaluating mean top1_acc:0.027, top5_acc:0.089 of current validation epoch
    2022-07-05 01:46:54 - Evaluating mean top1_acc:0.029, top5_acc:0.102 of current training epoch
    2022-07-05 01:51:35 - Evaluating mean top1_acc:0.017, top5_acc:0.060 of current validation epoch
    2022-07-05 03:42:59 - Evaluating mean top1_acc:0.026, top5_acc:0.092 of current training epoch
    2022-07-05 03:47:38 - Evaluating mean top1_acc:0.016, top5_acc:0.056 of current validation epoch
    2022-07-05 05:42:18 - Evaluating mean top1_acc:0.027, top5_acc:0.096 of current training epoch
    2022-07-05 05:46:48 - Evaluating mean top1_acc:0.013, top5_acc:0.054 of current validation epoch
    2022-07-05 07:35:56 - Evaluating mean top1_acc:0.028, top5_acc:0.096 of current training epoch
    2022-07-05 07:40:33 - Evaluating mean top1_acc:0.017, top5_acc:0.063 of current validation epoch
    2022-07-05 09:32:25 - Evaluating mean top1_acc:0.028, top5_acc:0.099 of current training epoch
    2022-07-05 09:37:00 - Evaluating mean top1_acc:0.017, top5_acc:0.066 of current validation epoch
    2022-07-05 11:28:31 - Evaluating mean top1_acc:0.029, top5_acc:0.101 of current training epoch
    2022-07-05 11:33:02 - Evaluating mean top1_acc:0.017, top5_acc:0.062 of current validation epoch

    opened by muzhaohui 1
  • AttributeError: 'VideoTransformer' object has no attribute 'weight_decay'

    I got this error until I changed the following line in model_trainer.py:

    param_group["weight_decay"] = self._get_momentum(base_value=self.weight_decay, final_value=self.configs.weight_decay_end)

    to

    param_group["weight_decay"] = self._get_momentum(base_value=self.configs.weight_decay, final_value=self.configs.weight_decay_end)

    opened by Enclavet 1
  • build_finetune_optimizer raise NotImplementedError

    Why does build_finetune_optimizer raise NotImplementedError when hparams.arch is not mvit? I used the training command in the README to fine-tune ViViT.

    def build_finetune_optimizer(hparams, model):
    	if hparams.arch == 'mvit':
    		if hparams.layer_decay == 1:
    			get_layer_func = None
    			scales = None
    		else:
    			num_layers = 16
    			get_layer_func = partial(get_mvit_layer, num_layers=num_layers + 2)
    			scales = list(hparams.layer_decay ** i for i in reversed(range(num_layers + 2)))
    	else:
    		raise NotImplementedError
    
    opened by aries-young 1
  • Maskfeat downstream task performance

    I tried to fine-tune a classifier with the MaskFeat pretrained weights you provided, but the final performance was terrible (UCF101 Acc@top1 = 52%). What is your performance when fine-tuning from MaskFeat, and what are your MViT fine-tune settings?

    opened by WHlTE-N0lSE 1
  • Log-File for ViViT finetuning with Imagenet pre-train Weights

    Hi @mx-mark Do you have a log file for experiment of ViViT fine-tuning with Imagenet-21k pre-train weights?

    I am referring to following experiment:

    python model_pretrain.py \
    	-lr 0.005 -epoch 30 -batch_size 8 -num_workers 4 -num_frames 16 -frame_interval 16 -num_class 400 \
    	-arch 'vivit' -attention_type 'fact_encoder' -optim_type 'sgd' -lr_schedule 'cosine' \
    	-objective 'supervised' -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH \
    	-val_data_path $VAL_DATA_PATH -pretrain_pth $PRETRAIN_WEIGHTS -weights_from 'imagenet'

    opened by asif-hanif 0
  • Question about Loading a pretrained model(ViT)

    Hello, thanks for your work. I have a simple question: I downloaded pretrained ViT weights from the Google Research GitHub, and I would like to know how I can verify that my ViViT model was initialized successfully from those pretrained ViT weights.

    opened by wonkicho 0
  • Errors when loading pretrained weights -pretrain_pth 'vivit_model.pth' -weights_from 'kinetics'

    When I try to fine-tune on my own dataset starting from the pretrained Kinetics ViViT model, the following errors occur. I am new to PyTorch; may I know how to solve them? Thanks.

    command

    python model_pretrain.py \
    	-lr 0.001 -epoch 100 -batch_size 32 -num_workers 4  -frame_interval 16  \
    	-arch 'vivit' -attention_type 'fact_encoder' -optim_type 'sgd' -lr_schedule 'cosine' \
    	-objective 'supervised' -root_dir ./ \
        -gpus 0 -num_class 2 -img_size 50 -num_frames 13 \
        -warmup_epochs 5 \
        -pretrain_pth 'vivit_model.pth' -weights_from 'kinetics'
    

    Errors:

    RuntimeError: Error(s) in loading state_dict for ViViT:
    File "/home/VideoTransformer-pytorch/weight_init.py", line 319, in init_from_kinetics_pretrain_
        msg = module.load_state_dict(state_dict, strict=False)
      File "/home/anaconda3/envs/pytorchvideo/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
            size mismatch for pos_embed: copying a param with shape torch.Size([1, 197, 768]) from checkpoint, the shape in current model is torch.Size([1, 10, 768]).
            size mismatch for time_embed: copying a param with shape torch.Size([1, 9, 768]) from checkpoint, the shape in current model is torch.Size([1, 7, 768]).
    
    opened by desti-nation 0
  • How to dataloader?

    Hello, thank you very much for your outstanding work. I am new to computer vision, and I didn't see how the images are loaded into the model. Could you tell me how to extract 16 frames from a video and input them into the ViViT model? Looking forward to your reply.

    opened by SuperGentry 2
  • How to load Tensorflow checkpoints?

    Hello, thanks for your great work. I have successfully trained ViViT. However, only a few checkpoints are available. In another issue you mentioned that the pre-trained models come from Google's original repo. Could you kindly share the conversion code or explain the method?

    opened by realgump 3