A new codebase for Group Activity Recognition. It contains code for the ICCV 2021 paper Spatio-Temporal Dynamic Inference Network for Group Activity Recognition, along with several other methods.

Overview

Spatio-Temporal Dynamic Inference Network for Group Activity Recognition

Source code for the ICCV 2021 paper: Spatio-Temporal Dynamic Inference Network for Group Activity Recognition.
[paper] [supplemental material] [arXiv]

If you find our work or the codebase inspiring and useful for your research, please cite:

@inproceedings{yuan2021DIN,
  title={Spatio-Temporal Dynamic Inference Network for Group Activity Recognition},
  author={Yuan, Hangjie and Ni, Dong and Wang, Mang},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={7476--7485},
  year={2021}
}

Dependencies

  • Software Environment: Linux (CentOS 7)
  • Hardware Environment: NVIDIA TITAN RTX
  • Python 3.6
  • PyTorch 1.2.0, Torchvision 0.4.0
  • RoIAlign for PyTorch

Prepare Datasets

  1. Download the publicly available datasets from the following links: Volleyball dataset and Collective Activity dataset.
  2. Unzip the dataset file into data/volleyball or data/collective.
  3. Download the file tracks_normalized.pkl from cvlab-epfl/social-scene-understanding and put it into data/volleyball/videos.
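
To sanity-check the download, here is a minimal Python sketch; the internal structure of tracks_normalized.pkl is not documented here, so it only loads the pickle and prints a sample entry.

    import pickle

    # Load the tracking annotations from social-scene-understanding.
    with open('data/volleyball/videos/tracks_normalized.pkl', 'rb') as f:
        tracks = pickle.load(f)

    # Print the container type and one sample key to confirm the file is intact.
    print(type(tracks), len(tracks))
    sample_key = next(iter(tracks))
    print(sample_key, type(tracks[sample_key]))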

Using Docker

  1. Check out the repository and cd PROJECT_PATH

  2. Build the Docker container

docker build -t din_gar https://github.com/JacobYuan7/DIN_GAR.git#main
  3. Run the Docker container
docker run --shm-size=2G -v data/volleyball:/opt/DIN_GAR/data/volleyball -v result:/opt/DIN_GAR/result --rm -it din_gar
  • --shm-size=2G: extends the container's shared memory to prevent ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm). Alternatively, use --ipc=host.
  • -v data/volleyball:/opt/DIN_GAR/data/volleyball: Makes the host's folder data/volleyball available inside the container at /opt/DIN_GAR/data/volleyball
  • -v result:/opt/DIN_GAR/result: Makes the host's folder result available inside the container at /opt/DIN_GAR/result
  • -it & --rm: Starts the container with an interactive session (PROJECT_PATH is /opt/DIN_GAR) and removes the container after closing the session.
  • din_gar: the name/tag of the image
  • optional: --gpus='"device=7"' restricts the GPU devices the container can access.

Get Started

  1. Train the Base Model: Fine-tune the base model on the dataset.

    # Volleyball dataset
    cd PROJECT_PATH 
    python scripts/train_volleyball_stage1.py
    
    # Collective Activity dataset
    cd PROJECT_PATH 
    python scripts/train_collective_stage1.py
  2. Train with the reasoning module: Append the reasoning module to the base model to obtain the reasoning model.

    1. Volleyball dataset

      • DIN

        python scripts/train_volleyball_stage2_dynamic.py
        
      • lite DIN
        We can run the lite version of DIN by setting cfg.lite_dim = 128 in scripts/train_volleyball_stage2_dynamic.py.

        python scripts/train_volleyball_stage2_dynamic.py
        
      • ST-factorized DIN
        We can run ST-factorized DIN by setting cfg.ST_kernel_size = [(1,3),(3,1)] and cfg.hierarchical_inference = True.

        Note that if you instead set cfg.hierarchical_inference = False, cfg.ST_kernel_size = [(1,3),(3,1)] and cfg.num_DIN = 2, multiple interaction fields run in parallel (see the combined configuration sketch after this list).

        python scripts/train_volleyball_stage2_dynamic.py
        

      Other models re-implemented by us according to their papers or publicly available code:

      • AT
        python scripts/train_volleyball_stage2_at.py
        
      • PCTDM
        python scripts/train_volleyball_stage2_pctdm.py
        
      • SACRF
        python scripts/train_volleyball_stage2_sacrf_biute.py
        
      • ARG
        python scripts/train_volleyball_stage2_arg.py
        
      • HiGCIN
        python scripts/train_volleyball_stage2_higcin.py
        
    2. Collective Activity dataset

      • DIN
        python scripts/train_collective_stage2_dynamic.py
        
      • DIN lite
        We can run the lite version of DIN by setting cfg.lite_dim = 128 in scripts/train_collective_stage2_dynamic.py.
        python scripts/train_collective_stage2_dynamic.py
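
The DIN variants above differ only in a few configuration flags. For reference, a combined sketch (the flag names are taken from this README; where exactly they are set inside the stage-2 training scripts is left to the reader to confirm):

    # DIN variant configuration, as described above (e.g. in
    # scripts/train_volleyball_stage2_dynamic.py, after cfg is constructed).
    cfg.lite_dim = 128                     # lite DIN: reduced embedding width
    cfg.ST_kernel_size = [(1, 3), (3, 1)]  # factorized spatial/temporal kernels
    cfg.hierarchical_inference = True      # ST-factorized DIN (hierarchical)

    # Alternatively, run multiple interaction fields in parallel:
    # cfg.hierarchical_inference = False
    # cfg.num_DIN = 2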
        

Another work of ours, which approaches GAR from the perspective of incorporating visual context, is also available:

@inproceedings{yuan2021visualcontext,
  title={Learning Visual Context for Group Activity Recognition},
  author={Yuan, Hangjie and Ni, Dong},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={4},
  pages={3261--3269},
  year={2021}
}
Comments
  • Was the validation set used for training?

    Thanks for sharing the code! However, according to the code, it seems that the validation set of the volleyball dataset was used as part of the training set.

    https://github.com/cvlab-epfl/social-scene-understanding/blob/b26080785f6ef1c1e0d08ab496744e5e8c1c8d4c/volleyball.py#L4 vs. https://github.com/JacobYuan7/DIN_GAR/blob/7c5a8f54cc3b713ead438cc47dd54c491c837fbf/config.py#L28

    May I ask what the reason or justification for this is? Did any prior work do the same thing?

    opened by hongluzhou 17
  • Is there a detailed explanation of the MCA and MPCA metrics?

    Hello, and thank you for this excellent work. While reading the paper, I tried to find information about the MCA and MPCA metrics.

    Is MCA the group activity recognition accuracy shown in the commented-out benchmark section of the code? How should the hyperparameters be set to reach accuracies close to those reported in the paper? And where can I find a description of the MPCA metric?

    Could you also explain how the #Params parameter count is computed?

    opened by Kev1n3zz 11
  • Stage 2: FileNotFoundError basemodel_VD_res18.pth & basemodel_VD_vgg16.pth

    Thanks for sharing your code!

    train_volleyball_stage1.py works on my machine, but every train_volleyball_stage2_* script throws a FileNotFoundError: either result/basemodel_VD_res18.pth or result/basemodel_VD_vgg16.pth cannot be found.
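
    In other words, the stage-2 scripts expect the stage-1 checkpoint under result/. A minimal pre-flight check (a sketch; the file names are those from the error above):

        import os

        # Stage 2 loads the fine-tuned base model saved by stage 1; verify it exists.
        for ckpt in ('result/basemodel_VD_res18.pth', 'result/basemodel_VD_vgg16.pth'):
            print(ckpt, 'found' if os.path.isfile(ckpt) else 'MISSING - run stage 1 first')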

    opened by daniel-richter 9
  • Stage-2 training on the Collective dataset: error when loading the model

    Hello, and thank you for open-sourcing the code. Stage-1 and stage-2 training on the Volleyball dataset both ran smoothly, but when training on CAD I used ResNet-18 for stage 1 and hit a weight-loading error in stage 2, even though the model config file is the unmodified default. Could you help me figure out why?

        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for MyRes18:
            Missing key(s) in state_dict: "features.0.weight", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.0.conv1.weight", "features.4.0.bn1.weight", "features.4.0.bn1.bias", "features.4.0.bn1.running_mean", "features.4.0.bn1.running_var", "features.4.0.conv2.weight", "features.4.0.bn2.weight", "features.4.0.bn2.bias", "features.4.0.bn2.running_mean", "features.4.0.bn2.running_var", "features.4.1.conv1.weight", "features.4.1.bn1.weight", "features.4.1.bn1.bias", "features.4.1.bn1.running_mean", "features.4.1.bn1.running_var", "features.4.1.conv2.weight", "features.4.1.bn2.weight", "features.4.1.bn2.bias", "features.4.1.bn2.running_mean", "features.4.1.bn2.running_var", "features.5.0.conv1.weight", "features.5.0.bn1.weight", "features.5.0.bn1.bias", "features.5.0.bn1.running_mean", "features.5.0.bn1.running_var", "features.5.0.conv2.weight", "features.5.0.bn2.weight", "features.5.0.bn2.bias", "features.5.0.bn2.running_mean", "features.5.0.bn2.running_var", "features.5.0.downsample.0.weight", "features.5.0.downsample.1.weight", "features.5.0.downsample.1.bias", "features.5.0.downsample.1.running_mean", "features.5.0.downsample.1.running_var", "features.5.1.conv1.weight", "features.5.1.bn1.weight", "features.5.1.bn1.bias", "features.5.1.bn1.running_mean", "features.5.1.bn1.running_var", "features.5.1.conv2.weight", "features.5.1.bn2.weight", "features.5.1.bn2.bias", "features.5.1.bn2.running_mean", "features.5.1.bn2.running_var", "features.6.0.conv1.weight", "features.6.0.bn1.weight", "features.6.0.bn1.bias", "features.6.0.bn1.running_mean", "features.6.0.bn1.running_var", "features.6.0.conv2.weight", "features.6.0.bn2.weight", "features.6.0.bn2.bias", "features.6.0.bn2.running_mean", "features.6.0.bn2.running_var", "features.6.0.downsample.0.weight", "features.6.0.downsample.1.weight", "features.6.0.downsample.1.bias", "features.6.0.downsample.1.running_mean", "features.6.0.downsample.1.running_var", "features.6.1.conv1.weight", "features.6.1.bn1.weight", "features.6.1.bn1.bias", "features.6.1.bn1.running_mean", "features.6.1.bn1.running_var", "features.6.1.conv2.weight", "features.6.1.bn2.weight", "features.6.1.bn2.bias", "features.6.1.bn2.running_mean", "features.6.1.bn2.running_var", "features.7.0.conv1.weight", "features.7.0.bn1.weight", "features.7.0.bn1.bias", "features.7.0.bn1.running_mean", "features.7.0.bn1.running_var", "features.7.0.conv2.weight", "features.7.0.bn2.weight", "features.7.0.bn2.bias", "features.7.0.bn2.running_mean", "features.7.0.bn2.running_var", "features.7.0.downsample.0.weight", "features.7.0.downsample.1.weight", "features.7.0.downsample.1.bias", "features.7.0.downsample.1.running_mean", "features.7.0.downsample.1.running_var", "features.7.1.conv1.weight", "features.7.1.bn1.weight", "features.7.1.bn1.bias", "features.7.1.bn1.running_mean", "features.7.1.bn1.running_var", "features.7.1.conv2.weight", "features.7.1.bn2.weight", "features.7.1.bn2.bias", "features.7.1.bn2.running_mean", "features.7.1.bn2.running_var". 
            Unexpected key(s) in state_dict: "Conv2d_1a_3x3.conv.weight", "Conv2d_1a_3x3.bn.weight", "Conv2d_1a_3x3.bn.bias", "Conv2d_1a_3x3.bn.running_mean", "Conv2d_1a_3x3.bn.running_var", "Conv2d_1a_3x3.bn.num_batches_tracked", "Conv2d_2a_3x3.conv.weight", "Conv2d_2a_3x3.bn.weight", "Conv2d_2a_3x3.bn.bias", "Conv2d_2a_3x3.bn.running_mean", "Conv2d_2a_3x3.bn.running_var", "Conv2d_2a_3x3.bn.num_batches_tracked", "Conv2d_2b_3x3.conv.weight", "Conv2d_2b_3x3.bn.weight", "Conv2d_2b_3x3.bn.bias", "Conv2d_2b_3x3.bn.running_mean", "Conv2d_2b_3x3.bn.running_var", "Conv2d_2b_3x3.bn.num_batches_tracked", "Conv2d_3b_1x1.conv.weight", "Conv2d_3b_1x1.bn.weight", "Conv2d_3b_1x1.bn.bias", "Conv2d_3b_1x1.bn.running_mean", "Conv2d_3b_1x1.bn.running_var", "Conv2d_3b_1x1.bn.num_batches_tracked", "Conv2d_4a_3x3.conv.weight", "Conv2d_4a_3x3.bn.weight", "Conv2d_4a_3x3.bn.bias", "Conv2d_4a_3x3.bn.running_mean", "Conv2d_4a_3x3.bn.running_var", "Conv2d_4a_3x3.bn.num_batches_tracked", "Mixed_5b.branch1x1.conv.weight", "Mixed_5b.branch1x1.bn.weight", "Mixed_5b.branch1x1.bn.bias", "Mixed_5b.branch1x1.bn.running_mean", "Mixed_5b.branch1x1.bn.running_var", "Mixed_5b.branch1x1.bn.num_batches_tracked", "Mixed_5b.branch5x5_1.conv.weight", "Mixed_5b.branch5x5_1.bn.weight", "Mixed_5b.branch5x5_1.bn.bias", "Mixed_5b.branch5x5_1.bn.running_mean", "Mixed_5b.branch5x5_1.bn.running_var", "Mixed_5b.branch5x5_1.bn.num_batches_tracked", "Mixed_5b.branch5x5_2.conv.weight", "Mixed_5b.branch5x5_2.bn.weight", "Mixed_5b.branch5x5_2.bn.bias", "Mixed_5b.branch5x5_2.bn.running_mean", "Mixed_5b.branch5x5_2.bn.running_var", "Mixed_5b.branch5x5_2.bn.num_batches_tracked", "Mixed_5b.branch3x3dbl_1.conv.weight", "Mixed_5b.branch3x3dbl_1.bn.weight", "Mixed_5b.branch3x3dbl_1.bn.bias", "Mixed_5b.branch3x3dbl_1.bn.running_mean", "Mixed_5b.branch3x3dbl_1.bn.running_var", "Mixed_5b.branch3x3dbl_1.bn.num_batches_tracked", "Mixed_5b.branch3x3dbl_2.conv.weight", "Mixed_5b.branch3x3dbl_2.bn.weight", "Mixed_5b.branch3x3dbl_2.bn.bias", "Mixed_5b.branch3x3dbl_2.bn.running_mean", "Mixed_5b.branch3x3dbl_2.bn.running_var", "Mixed_5b.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_5b.branch3x3dbl_3.conv.weight", "Mixed_5b.branch3x3dbl_3.bn.weight", "Mixed_5b.branch3x3dbl_3.bn.bias", "Mixed_5b.branch3x3dbl_3.bn.running_mean", "Mixed_5b.branch3x3dbl_3.bn.running_var", "Mixed_5b.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_5b.branch_pool.conv.weight", "Mixed_5b.branch_pool.bn.weight", "Mixed_5b.branch_pool.bn.bias", "Mixed_5b.branch_pool.bn.running_mean", "Mixed_5b.branch_pool.bn.running_var", "Mixed_5b.branch_pool.bn.num_batches_tracked", "Mixed_5c.branch1x1.conv.weight", "Mixed_5c.branch1x1.bn.weight", "Mixed_5c.branch1x1.bn.bias", "Mixed_5c.branch1x1.bn.running_mean", "Mixed_5c.branch1x1.bn.running_var", "Mixed_5c.branch1x1.bn.num_batches_tracked", "Mixed_5c.branch5x5_1.conv.weight", "Mixed_5c.branch5x5_1.bn.weight", "Mixed_5c.branch5x5_1.bn.bias", "Mixed_5c.branch5x5_1.bn.running_mean", "Mixed_5c.branch5x5_1.bn.running_var", "Mixed_5c.branch5x5_1.bn.num_batches_tracked", "Mixed_5c.branch5x5_2.conv.weight", "Mixed_5c.branch5x5_2.bn.weight", "Mixed_5c.branch5x5_2.bn.bias", "Mixed_5c.branch5x5_2.bn.running_mean", "Mixed_5c.branch5x5_2.bn.running_var", "Mixed_5c.branch5x5_2.bn.num_batches_tracked", "Mixed_5c.branch3x3dbl_1.conv.weight", "Mixed_5c.branch3x3dbl_1.bn.weight", "Mixed_5c.branch3x3dbl_1.bn.bias", "Mixed_5c.branch3x3dbl_1.bn.running_mean", "Mixed_5c.branch3x3dbl_1.bn.running_var", "Mixed_5c.branch3x3dbl_1.bn.num_batches_tracked", 
"Mixed_5c.branch3x3dbl_2.conv.weight", "Mixed_5c.branch3x3dbl_2.bn.weight", "Mixed_5c.branch3x3dbl_2.bn.bias", "Mixed_5c.branch3x3dbl_2.bn.running_mean", "Mixed_5c.branch3x3dbl_2.bn.running_var", "Mixed_5c.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_5c.branch3x3dbl_3.conv.weight", "Mixed_5c.branch3x3dbl_3.bn.weight", "Mixed_5c.branch3x3dbl_3.bn.bias", "Mixed_5c.branch3x3dbl_3.bn.running_mean", "Mixed_5c.branch3x3dbl_3.bn.running_var", "Mixed_5c.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_5c.branch_pool.conv.weight", "Mixed_5c.branch_pool.bn.weight", "Mixed_5c.branch_pool.bn.bias", "Mixed_5c.branch_pool.bn.running_mean", "Mixed_5c.branch_pool.bn.running_var", "Mixed_5c.branch_pool.bn.num_batches_tracked", "Mixed_5d.branch1x1.conv.weight", "Mixed_5d.branch1x1.bn.weight", "Mixed_5d.branch1x1.bn.bias", "Mixed_5d.branch1x1.bn.running_mean", "Mixed_5d.branch1x1.bn.running_var", "Mixed_5d.branch1x1.bn.num_batches_tracked", "Mixed_5d.branch5x5_1.conv.weight", "Mixed_5d.branch5x5_1.bn.weight", "Mixed_5d.branch5x5_1.bn.bias", "Mixed_5d.branch5x5_1.bn.running_mean", "Mixed_5d.branch5x5_1.bn.running_var", "Mixed_5d.branch5x5_1.bn.num_batches_tracked", "Mixed_5d.branch5x5_2.conv.weight", "Mixed_5d.branch5x5_2.bn.weight", "Mixed_5d.branch5x5_2.bn.bias", "Mixed_5d.branch5x5_2.bn.running_mean", "Mixed_5d.branch5x5_2.bn.running_var", "Mixed_5d.branch5x5_2.bn.num_batches_tracked", "Mixed_5d.branch3x3dbl_1.conv.weight", "Mixed_5d.branch3x3dbl_1.bn.weight", "Mixed_5d.branch3x3dbl_1.bn.bias", "Mixed_5d.branch3x3dbl_1.bn.running_mean", "Mixed_5d.branch3x3dbl_1.bn.running_var", "Mixed_5d.branch3x3dbl_1.bn.num_batches_tracked", "Mixed_5d.branch3x3dbl_2.conv.weight", "Mixed_5d.branch3x3dbl_2.bn.weight", "Mixed_5d.branch3x3dbl_2.bn.bias", "Mixed_5d.branch3x3dbl_2.bn.running_mean", "Mixed_5d.branch3x3dbl_2.bn.running_var", "Mixed_5d.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_5d.branch3x3dbl_3.conv.weight", "Mixed_5d.branch3x3dbl_3.bn.weight", "Mixed_5d.branch3x3dbl_3.bn.bias", "Mixed_5d.branch3x3dbl_3.bn.running_mean", "Mixed_5d.branch3x3dbl_3.bn.running_var", "Mixed_5d.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_5d.branch_pool.conv.weight", "Mixed_5d.branch_pool.bn.weight", "Mixed_5d.branch_pool.bn.bias", "Mixed_5d.branch_pool.bn.running_mean", "Mixed_5d.branch_pool.bn.running_var", "Mixed_5d.branch_pool.bn.num_batches_tracked", "Mixed_6a.branch3x3.conv.weight", "Mixed_6a.branch3x3.bn.weight", "Mixed_6a.branch3x3.bn.bias", "Mixed_6a.branch3x3.bn.running_mean", "Mixed_6a.branch3x3.bn.running_var", "Mixed_6a.branch3x3.bn.num_batches_tracked", "Mixed_6a.branch3x3dbl_1.conv.weight", "Mixed_6a.branch3x3dbl_1.bn.weight", "Mixed_6a.branch3x3dbl_1.bn.bias", "Mixed_6a.branch3x3dbl_1.bn.running_mean", "Mixed_6a.branch3x3dbl_1.bn.running_var", "Mixed_6a.branch3x3dbl_1.bn.num_batches_tracked", "Mixed_6a.branch3x3dbl_2.conv.weight", "Mixed_6a.branch3x3dbl_2.bn.weight", "Mixed_6a.branch3x3dbl_2.bn.bias", "Mixed_6a.branch3x3dbl_2.bn.running_mean", "Mixed_6a.branch3x3dbl_2.bn.running_var", "Mixed_6a.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_6a.branch3x3dbl_3.conv.weight", "Mixed_6a.branch3x3dbl_3.bn.weight", "Mixed_6a.branch3x3dbl_3.bn.bias", "Mixed_6a.branch3x3dbl_3.bn.running_mean", "Mixed_6a.branch3x3dbl_3.bn.running_var", "Mixed_6a.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_6b.branch1x1.conv.weight", "Mixed_6b.branch1x1.bn.weight", "Mixed_6b.branch1x1.bn.bias", "Mixed_6b.branch1x1.bn.running_mean", "Mixed_6b.branch1x1.bn.running_var", "Mixed_6b.branch1x1.bn.num_batches_tracked", 
"Mixed_6b.branch7x7_1.conv.weight", "Mixed_6b.branch7x7_1.bn.weight", "Mixed_6b.branch7x7_1.bn.bias", "Mixed_6b.branch7x7_1.bn.running_mean", "Mixed_6b.branch7x7_1.bn.running_var", "Mixed_6b.branch7x7_1.bn.num_batches_tracked", "Mixed_6b.branch7x7_2.conv.weight", "Mixed_6b.branch7x7_2.bn.weight", "Mixed_6b.branch7x7_2.bn.bias", "Mixed_6b.branch7x7_2.bn.running_mean", "Mixed_6b.branch7x7_2.bn.running_var", "Mixed_6b.branch7x7_2.bn.num_batches_tracked", "Mixed_6b.branch7x7_3.conv.weight", "Mixed_6b.branch7x7_3.bn.weight", "Mixed_6b.branch7x7_3.bn.bias", "Mixed_6b.branch7x7_3.bn.running_mean", "Mixed_6b.branch7x7_3.bn.running_var", "Mixed_6b.branch7x7_3.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_1.conv.weight", "Mixed_6b.branch7x7dbl_1.bn.weight", "Mixed_6b.branch7x7dbl_1.bn.bias", "Mixed_6b.branch7x7dbl_1.bn.running_mean", "Mixed_6b.branch7x7dbl_1.bn.running_var", "Mixed_6b.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_2.conv.weight", "Mixed_6b.branch7x7dbl_2.bn.weight", "Mixed_6b.branch7x7dbl_2.bn.bias", "Mixed_6b.branch7x7dbl_2.bn.running_mean", "Mixed_6b.branch7x7dbl_2.bn.running_var", "Mixed_6b.branch7x7dbl_2.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_3.conv.weight", "Mixed_6b.branch7x7dbl_3.bn.weight", "Mixed_6b.branch7x7dbl_3.bn.bias", "Mixed_6b.branch7x7dbl_3.bn.running_mean", "Mixed_6b.branch7x7dbl_3.bn.running_var", "Mixed_6b.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_4.conv.weight", "Mixed_6b.branch7x7dbl_4.bn.weight", "Mixed_6b.branch7x7dbl_4.bn.bias", "Mixed_6b.branch7x7dbl_4.bn.running_mean", "Mixed_6b.branch7x7dbl_4.bn.running_var", "Mixed_6b.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_5.conv.weight", "Mixed_6b.branch7x7dbl_5.bn.weight", "Mixed_6b.branch7x7dbl_5.bn.bias", "Mixed_6b.branch7x7dbl_5.bn.running_mean", "Mixed_6b.branch7x7dbl_5.bn.running_var", "Mixed_6b.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6b.branch_pool.conv.weight", "Mixed_6b.branch_pool.bn.weight", "Mixed_6b.branch_pool.bn.bias", "Mixed_6b.branch_pool.bn.running_mean", "Mixed_6b.branch_pool.bn.running_var", "Mixed_6b.branch_pool.bn.num_batches_tracked", "Mixed_6c.branch1x1.conv.weight", "Mixed_6c.branch1x1.bn.weight", "Mixed_6c.branch1x1.bn.bias", "Mixed_6c.branch1x1.bn.running_mean", "Mixed_6c.branch1x1.bn.running_var", "Mixed_6c.branch1x1.bn.num_batches_tracked", "Mixed_6c.branch7x7_1.conv.weight", "Mixed_6c.branch7x7_1.bn.weight", "Mixed_6c.branch7x7_1.bn.bias", "Mixed_6c.branch7x7_1.bn.running_mean", "Mixed_6c.branch7x7_1.bn.running_var", "Mixed_6c.branch7x7_1.bn.num_batches_tracked", "Mixed_6c.branch7x7_2.conv.weight", "Mixed_6c.branch7x7_2.bn.weight", "Mixed_6c.branch7x7_2.bn.bias", "Mixed_6c.branch7x7_2.bn.running_mean", "Mixed_6c.branch7x7_2.bn.running_var", "Mixed_6c.branch7x7_2.bn.num_batches_tracked", "Mixed_6c.branch7x7_3.conv.weight", "Mixed_6c.branch7x7_3.bn.weight", "Mixed_6c.branch7x7_3.bn.bias", "Mixed_6c.branch7x7_3.bn.running_mean", "Mixed_6c.branch7x7_3.bn.running_var", "Mixed_6c.branch7x7_3.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_1.conv.weight", "Mixed_6c.branch7x7dbl_1.bn.weight", "Mixed_6c.branch7x7dbl_1.bn.bias", "Mixed_6c.branch7x7dbl_1.bn.running_mean", "Mixed_6c.branch7x7dbl_1.bn.running_var", "Mixed_6c.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_2.conv.weight", "Mixed_6c.branch7x7dbl_2.bn.weight", "Mixed_6c.branch7x7dbl_2.bn.bias", "Mixed_6c.branch7x7dbl_2.bn.running_mean", "Mixed_6c.branch7x7dbl_2.bn.running_var", "Mixed_6c.branch7x7dbl_2.bn.num_batches_tracked", 
"Mixed_6c.branch7x7dbl_3.conv.weight", "Mixed_6c.branch7x7dbl_3.bn.weight", "Mixed_6c.branch7x7dbl_3.bn.bias", "Mixed_6c.branch7x7dbl_3.bn.running_mean", "Mixed_6c.branch7x7dbl_3.bn.running_var", "Mixed_6c.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_4.conv.weight", "Mixed_6c.branch7x7dbl_4.bn.weight", "Mixed_6c.branch7x7dbl_4.bn.bias", "Mixed_6c.branch7x7dbl_4.bn.running_mean", "Mixed_6c.branch7x7dbl_4.bn.running_var", "Mixed_6c.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_5.conv.weight", "Mixed_6c.branch7x7dbl_5.bn.weight", "Mixed_6c.branch7x7dbl_5.bn.bias", "Mixed_6c.branch7x7dbl_5.bn.running_mean", "Mixed_6c.branch7x7dbl_5.bn.running_var", "Mixed_6c.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6c.branch_pool.conv.weight", "Mixed_6c.branch_pool.bn.weight", "Mixed_6c.branch_pool.bn.bias", "Mixed_6c.branch_pool.bn.running_mean", "Mixed_6c.branch_pool.bn.running_var", "Mixed_6c.branch_pool.bn.num_batches_tracked", "Mixed_6d.branch1x1.conv.weight", "Mixed_6d.branch1x1.bn.weight", "Mixed_6d.branch1x1.bn.bias", "Mixed_6d.branch1x1.bn.running_mean", "Mixed_6d.branch1x1.bn.running_var", "Mixed_6d.branch1x1.bn.num_batches_tracked", "Mixed_6d.branch7x7_1.conv.weight", "Mixed_6d.branch7x7_1.bn.weight", "Mixed_6d.branch7x7_1.bn.bias", "Mixed_6d.branch7x7_1.bn.running_mean", "Mixed_6d.branch7x7_1.bn.running_var", "Mixed_6d.branch7x7_1.bn.num_batches_tracked", "Mixed_6d.branch7x7_2.conv.weight", "Mixed_6d.branch7x7_2.bn.weight", "Mixed_6d.branch7x7_2.bn.bias", "Mixed_6d.branch7x7_2.bn.running_mean", "Mixed_6d.branch7x7_2.bn.running_var", "Mixed_6d.branch7x7_2.bn.num_batches_tracked", "Mixed_6d.branch7x7_3.conv.weight", "Mixed_6d.branch7x7_3.bn.weight", "Mixed_6d.branch7x7_3.bn.bias", "Mixed_6d.branch7x7_3.bn.running_mean", "Mixed_6d.branch7x7_3.bn.running_var", "Mixed_6d.branch7x7_3.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_1.conv.weight", "Mixed_6d.branch7x7dbl_1.bn.weight", "Mixed_6d.branch7x7dbl_1.bn.bias", "Mixed_6d.branch7x7dbl_1.bn.running_mean", "Mixed_6d.branch7x7dbl_1.bn.running_var", "Mixed_6d.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_2.conv.weight", "Mixed_6d.branch7x7dbl_2.bn.weight", "Mixed_6d.branch7x7dbl_2.bn.bias", "Mixed_6d.branch7x7dbl_2.bn.running_mean", "Mixed_6d.branch7x7dbl_2.bn.running_var", "Mixed_6d.branch7x7dbl_2.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_3.conv.weight", "Mixed_6d.branch7x7dbl_3.bn.weight", "Mixed_6d.branch7x7dbl_3.bn.bias", "Mixed_6d.branch7x7dbl_3.bn.running_mean", "Mixed_6d.branch7x7dbl_3.bn.running_var", "Mixed_6d.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_4.conv.weight", "Mixed_6d.branch7x7dbl_4.bn.weight", "Mixed_6d.branch7x7dbl_4.bn.bias", "Mixed_6d.branch7x7dbl_4.bn.running_mean", "Mixed_6d.branch7x7dbl_4.bn.running_var", "Mixed_6d.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_5.conv.weight", "Mixed_6d.branch7x7dbl_5.bn.weight", "Mixed_6d.branch7x7dbl_5.bn.bias", "Mixed_6d.branch7x7dbl_5.bn.running_mean", "Mixed_6d.branch7x7dbl_5.bn.running_var", "Mixed_6d.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6d.branch_pool.conv.weight", "Mixed_6d.branch_pool.bn.weight", "Mixed_6d.branch_pool.bn.bias", "Mixed_6d.branch_pool.bn.running_mean", "Mixed_6d.branch_pool.bn.running_var", "Mixed_6d.branch_pool.bn.num_batches_tracked", "Mixed_6e.branch1x1.conv.weight", "Mixed_6e.branch1x1.bn.weight", "Mixed_6e.branch1x1.bn.bias", "Mixed_6e.branch1x1.bn.running_mean", "Mixed_6e.branch1x1.bn.running_var", "Mixed_6e.branch1x1.bn.num_batches_tracked", 
"Mixed_6e.branch7x7_1.conv.weight", "Mixed_6e.branch7x7_1.bn.weight", "Mixed_6e.branch7x7_1.bn.bias", "Mixed_6e.branch7x7_1.bn.running_mean", "Mixed_6e.branch7x7_1.bn.running_var", "Mixed_6e.branch7x7_1.bn.num_batches_tracked", "Mixed_6e.branch7x7_2.conv.weight", "Mixed_6e.branch7x7_2.bn.weight", "Mixed_6e.branch7x7_2.bn.bias", "Mixed_6e.branch7x7_2.bn.running_mean", "Mixed_6e.branch7x7_2.bn.running_var", "Mixed_6e.branch7x7_2.bn.num_batches_tracked", "Mixed_6e.branch7x7_3.conv.weight", "Mixed_6e.branch7x7_3.bn.weight", "Mixed_6e.branch7x7_3.bn.bias", "Mixed_6e.branch7x7_3.bn.running_mean", "Mixed_6e.branch7x7_3.bn.running_var", "Mixed_6e.branch7x7_3.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_1.conv.weight", "Mixed_6e.branch7x7dbl_1.bn.weight", "Mixed_6e.branch7x7dbl_1.bn.bias", "Mixed_6e.branch7x7dbl_1.bn.running_mean", "Mixed_6e.branch7x7dbl_1.bn.running_var", "Mixed_6e.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_2.conv.weight", "Mixed_6e.branch7x7dbl_2.bn.weight", "Mixed_6e.branch7x7dbl_2.bn.bias", "Mixed_6e.branch7x7dbl_2.bn.running_mean", "Mixed_6e.branch7x7dbl_2.bn.running_var", "Mixed_6e.branch7x7dbl_2.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_3.conv.weight", "Mixed_6e.branch7x7dbl_3.bn.weight", "Mixed_6e.branch7x7dbl_3.bn.bias", "Mixed_6e.branch7x7dbl_3.bn.running_mean", "Mixed_6e.branch7x7dbl_3.bn.running_var", "Mixed_6e.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_4.conv.weight", "Mixed_6e.branch7x7dbl_4.bn.weight", "Mixed_6e.branch7x7dbl_4.bn.bias", "Mixed_6e.branch7x7dbl_4.bn.running_mean", "Mixed_6e.branch7x7dbl_4.bn.running_var", "Mixed_6e.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_5.conv.weight", "Mixed_6e.branch7x7dbl_5.bn.weight", "Mixed_6e.branch7x7dbl_5.bn.bias", "Mixed_6e.branch7x7dbl_5.bn.running_mean", "Mixed_6e.branch7x7dbl_5.bn.running_var", "Mixed_6e.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6e.branch_pool.conv.weight", "Mixed_6e.branch_pool.bn.weight", "Mixed_6e.branch_pool.bn.bias", "Mixed_6e.branch_pool.bn.running_mean", "Mixed_6e.branch_pool.bn.running_var", "Mixed_6e.branch_pool.bn.num_batches_tracked". 
    
    opened by zengxunli 7
  • RuntimeError: Error(s) in loading state_dict for Linear: size mismatch for weight: copying a param with shape torch.Size([1024, 26400]) from checkpoint, the shape in current model is torch.Size([1024, 12800]).

    Traceback (most recent call last):
      File "scripts/train_collective_stage2_dynamic.py", line 58, in <module>
        train_net(cfg)
      File "./train_net_dynamic.py", line 83, in train_net
        model.loadmodel(cfg.stage1_model_path)
      File "./infer_model.py", line 1222, in loadmodel
        self.fc_emb_1.load_state_dict(state['fc_emb_state_dict'])
      File "/home/hi/.conda/envs/lj0.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for Linear:
        size mismatch for weight: copying a param with shape torch.Size([1024, 26400]) from checkpoint, the shape in current model is torch.Size([1024, 12800]).
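
    The traceback shows stage 2 loading the fc_emb weights saved by stage 1. A diagnostic sketch for this kind of mismatch (the checkpoint path below is an assumption; substitute your own stage-1 model):

        import torch

        # Compare the shapes stored in the stage-1 checkpoint with the error message.
        state = torch.load('result/stage1_model.pth', map_location='cpu')  # path assumed
        for name, tensor in state['fc_emb_state_dict'].items():
            print(name, tuple(tensor.shape))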

    opened by Zxl1556180489 5
  • file tracks_normalized.pkl issue

    I have searched all the papers about group activity recognition and their source code. Your team and the other teams simply use this tracks_normalized.pkl file and point out where it comes from, so we don't know how the track boxes were normalized. I would appreciate any clarification of this problem.
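
    For reference, one common convention (an assumption on our part, not confirmed by the authors) is to divide pixel coordinates by the frame size so that boxes fall in [0, 1]:

        # Hypothetical normalization: (y1, x1, y2, x2) in pixels -> [0, 1] fractions.
        def normalize_box(y1, x1, y2, x2, frame_h, frame_w):
            return y1 / frame_h, x1 / frame_w, y2 / frame_h, x2 / frame_w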

    opened by reacher1130 5
  • How to set up HRNet-w32 of the TCE inference module?

    Hi, thanks for sharing the code for Learning Visual Context for Group Activity Recognition. However, I run into trouble: from hrnet.init_hrnet import cls_hrnet_w32, pose_hrnet_w32, as used in infer_module/TCE_STBiP_module.py, fails. Could you give me some help with solving this?

    Additionally, I have another point of confusion. The construction and usage of HRNet-w32 in your code are as follows:

    # construction
    if pose_net == 'pose_hrnet_w32':
        self.encoder = pose_hrnet_w32(pretrained=True)

    # usage
    roi_image = roi_image.cuda()
    # print(roi_image.shape)  # torch.Size([72, 3, 256, 192])
    roi_pose_feature = self.encoder(roi_image)
    

    As is known, HRNet-w32 outputs keypoint heatmaps at its final_layer. However, as mentioned in the paper, you ultimately obtain $d_e$-dimensional individual features from HRNet-w32, so I want to know how you convert the heatmaps into a feature vector of $d_e$ dimensions for the $i$-th person (or bounding box) at a certain frame.
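
    One generic recipe, shown purely as a sketch and not necessarily the paper's method, is to pool each RoI's feature maps spatially and project the result to $d_e$ dimensions:

        import torch
        import torch.nn as nn

        d_e = 256                           # assumed embedding size
        feat = torch.randn(72, 32, 64, 48)  # per-RoI feature maps (N, C, H, W); shapes illustrative
        pooled = feat.mean(dim=(2, 3))      # global average pooling -> (72, 32)
        proj = nn.Linear(pooled.shape[1], d_e)
        emb = proj(pooled)                  # (72, d_e) per-person feature vectors
        print(emb.shape)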

    I hope to hear from you. Thank you in advance!

    Regards,

    opened by 0shelter0 3
  • 'ARG_volleyball' object has no attribute 'fc_actions'

    When I try to execute python scripts/train_volleyball_stage2_arg.py, I get the following error:

    Load model states from:  result/basemodel_VD_vgg16.pth
    Traceback (most recent call last):
      File "scripts/train_volleyball_stage2_arg.py", line 40, in <module>
        train_net(cfg)
      File "./train_net_dynamic.py", line 112, in train_net
        test_info=test(validation_loader, model, device, 0, cfg)
      File "./train_net_dynamic.py", line 260, in test_volleyball
        ret = model((batch_data_test[0], batch_data_test[1]))
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "./infer_model.py", line 1003, in forward
        actions_scores = self.fc_actions(boxes_states_flat)  # B*T*N, actn_num
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
        type(self).__name__, name))
    AttributeError: 'ARG_volleyball' object has no attribute 'fc_actions'
    

    In class ARG_volleyball(nn.Module), the assignment of fc_actions is commented out: https://github.com/JacobYuan7/DIN_GAR/blob/7c5a8f54cc3b713ead438cc47dd54c491c837fbf/infer_model.py#L912-L913

    After re-enabling this line, the script works.
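
    For context, here is a hypothetical skeleton of what the re-enabled attribute amounts to (the dimensions below are placeholders; the actual definition is at the link above):

        import torch.nn as nn

        class ARG_volleyball_sketch(nn.Module):
            def __init__(self, box_state_dim=1024, num_actions=9):  # placeholder sizes
                super().__init__()
                # The assignment that is commented out in infer_model.py, re-enabled:
                self.fc_actions = nn.Linear(box_state_dim, num_actions)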

    opened by daniel-richter 1
  • Add Docker support

    I added a Dockerfile so you can run the experiments quickly within a container.

    • added Dockerfile
    • updated documentation (README.md)
    • removed user-specific absolute path from config.py
    • extended documentation for dataset configuration (https://github.com/JacobYuan7/DIN_GAR/issues/3)
    opened by daniel-richter 1
  • Stage 1 (Volleyball): Missing file tracks_normalized.pkl

    Executing scripts/train_volleyball_stage1.py fails because of the missing file tracks_normalized.pkl. To get it working, I downloaded the file available at cvlab-epfl/social-scene-understanding:

    cd data/volleyball/videos
    wget https://raw.githubusercontent.com/cvlab-epfl/social-scene-understanding/master/data/volleyball/tracks_normalized.pkl
    
    opened by daniel-richter 1
  • Actual Code

    Thanks for sharing your ICCV'21 paper! Do you plan to release the actual code before ICCV takes place? I would be very happy to get some insight into your code, as I think your paper fits my use case very well.

    opened by daniel-richter 1
  • About the visualization

    Hi Jacob, thanks for your great work! You draw heatmaps of the group interaction graph and the person interaction graph in the paper (Figure 5), and I notice that in dynamic_infer_module.py you also include code to save relation_list for visualization. I wonder how you draw the heatmaps; could you share the scripts?

    opened by SatMa34 0
Owner
A Ph.D. candidate and a realistic idealist.