Official PyTorch Implementation of "AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting".

AgentFormer

This repo contains the official implementation of our paper:

AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting
Ye Yuan, Xinshuo Weng, Yanglan Ou, Kris Kitani
ICCV 2021
[website] [paper]

Overview

[Figure: AgentFormer overview]

Important Note

We have recently noticed a normalization bug in the code; after fixing it, the performance of our method is worse than the numbers originally reported in the ICCV paper. For comparison, please use the corrected numbers in the updated arXiv version.

Installation

Environment

  • Tested OS: macOS, Linux
  • Python >= 3.7
  • PyTorch == 1.8.0

Dependencies:

  1. Install PyTorch 1.8.0 with the correct CUDA version.
  2. Install the dependencies (a quick sanity check of the install is sketched below):
    pip install -r requirements.txt
    
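To verify the install, a minimal sanity check (assuming a standard PyTorch setup; not part of the repo):

    import torch
    print(torch.__version__)          # expect 1.8.0
    print(torch.cuda.is_available())  # expect True on a working CUDA setup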

Datasets

  • For the ETH/UCY dataset, we already included a converted version compatible with our dataloader under datasets/eth_ucy.
  • For the nuScenes dataset, the following steps are required:
    1. Download the original nuScenes dataset. Check out the instructions here.
    2. Follow the instructions for the nuScenes prediction challenge. Download and install the map expansion.
    3. Run our script to obtain a processed version of the nuScenes dataset under datasets/nuscenes_pred (a devkit sanity check is sketched after these steps):
      python data/process_nuscenes.py --data_root <PATH_TO_NUSCENES>
      
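Before running the processing script, you can check that the download and devkit are set up correctly. A minimal sketch using the standard nuscenes-devkit API (the data root is a placeholder path):

    from nuscenes import NuScenes

    # loading the metadata will fail if the dataset layout is wrong
    nusc = NuScenes(version='v1.0-trainval', dataroot='/path/to/nuscenes')
    print(len(nusc.sample))  # number of annotated samples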

Pretrained Models

  • You can download pretrained models from Google Drive or BaiduYun (password: 9rvb) to reproduce the numbers in the paper.
  • Once the agentformer_models.zip file is downloaded, place it under the root folder of this repo and unzip it:
    unzip agentformer_models.zip
    
    This will place the models under the results folder. Note that the pretrained models directly correspond to the config files in cfg.

Evaluation

ETH/UCY

Run the following command to test pretrained models for the ETH dataset:

python test.py --cfg eth_agentformer --gpu 0

You can replace eth with one of {hotel, univ, zara1, zara2} to test the other datasets in ETH/UCY. You should be able to get the numbers reported in the paper, as shown in the table below (a sketch of how the ADE/FDE metrics are computed follows the table):

Dataset   ADE    FDE
ETH       0.45   0.75
Hotel     0.14   0.22
Univ      0.25   0.45
Zara1     0.18   0.30
Zara2     0.14   0.24
Avg       0.23   0.39
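ADE is the displacement error averaged over all predicted timesteps and FDE is the error at the final timestep; under the standard best-of-K protocol, both are taken over the best of the K sampled trajectories. A minimal sketch of the metrics for a single agent (illustrative only, not the repo's evaluation code):

    import numpy as np

    def ade_fde(pred, gt):
        # pred: (K, T, 2) sampled trajectories; gt: (T, 2) ground truth
        dist = np.linalg.norm(pred - gt[None], axis=-1)  # (K, T) per-step errors
        ade = dist.mean(axis=1).min()  # best-of-K average displacement error
        fde = dist[:, -1].min()        # best-of-K final displacement error
        return ade, fde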

nuScenes

Run the following command to test pretrained models for the nuScenes dataset:

python test.py --cfg nuscenes_5sample_agentformer --gpu 0

You can replace 5sample with 10sample to evaluate with 10 sampled trajectories; running both configs yields all four metrics (ADE_5, FDE_5, ADE_10, FDE_10), where the subscript is the number of samples K. You should be able to get the numbers reported in the paper, as shown in this table:

Method   ADE_5   FDE_5   ADE_10   FDE_10
Ours     1.856   3.889   1.452    2.856

Training

You can train your own models with your customized configs. Here we take the ETH dataset as an example, but you can train models for other datasets with their corresponding configs. AgentFormer requires two-stage training:

  1. Train the AgentFormer VAE model (everything but the trajectory sampler):
    python train.py --cfg user_eth_agentformer_pre --gpu 0
    
  2. Once the VAE model is trained, train the AgentFormer DLow model (trajectory sampler):
    python train.py --cfg user_eth_agentformer --gpu 0
    
    Note that you need to change the pred_cfg field in user_eth_agentformer to the config you used in step 1 (user_eth_agentformer_pre) and change pred_epoch to the VAE model epoch you want to use; an illustrative config fragment follows this list.
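For illustration only, those two fields in user_eth_agentformer would look like this (the epoch value is a placeholder; use the VAE checkpoint you actually trained):

    pred_cfg: user_eth_agentformer_pre   # the config used to train the VAE in step 1
    pred_epoch: 100                      # epoch of the VAE checkpoint to load (placeholder)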

Citation

If you find our work useful in your research, please cite our paper AgentFormer:

@inproceedings{yuan2021agent,
  title={AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting},
  author={Yuan, Ye and Weng, Xinshuo and Ou, Yanglan and Kitani, Kris},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2021}
}

License

Please see the license for further details.

Comments
  • Nuscene Dataset Question

    Hi,

    Thanks for your wonderful work on this paper, you guys did a good job!

    I have a question: did you use the training set obtained from all cameras on the Nuscenes dataset?

    Thanks ahead for your help!

    opened by zy1296 13
  • About normalization

    Hi, I have noticed that in the code (https://github.com/Khrylx/AgentFormer/blob/cf13e4033ceef7bdcfbf27183842415e7b841d34/model/agentformer.py#L535) the center of both the past trajectory and the future trajectory is used to normalize the input data. However, the future trajectory should not be available in the test. Is this data snooping? Please let me know if there is anything wrong with my understanding. Many thanks!

    opened by sjtuxcx 8
  • some puzzles about the math formulas in the CVAE Future Decoder part

    As described in section 3.2 of the paper: [equation screenshot] I can understand that the purpose of the MSE term ||Y - \hat{Y}||^2 is to push the real value Y and the mean of the Gaussian \hat{Y} as close together as possible, because the Gaussian distribution attains its maximum probability at the mean. But where does the weighting factor 1/(2\beta) come from? Why does the variance \beta lead to this weighting factor?
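    For reference, the factor follows directly from the Gaussian log-likelihood: if the decoder models p(Y | z) as a Gaussian with mean \hat{Y} and fixed variance \beta (a standard CVAE derivation, not a quote from the paper), then

        -\log \mathcal{N}(Y;\, \hat{Y},\, \beta I) = \frac{1}{2\beta}\,\|Y - \hat{Y}\|^2 + \frac{D}{2}\log(2\pi\beta)

    where D is the dimensionality of Y; the second term does not depend on \hat{Y}, so maximizing the likelihood amounts to minimizing the MSE weighted by 1/(2\beta).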

    opened by ultimatedigiman 5
  • Add Tepper dataset dataloader

    This PR supports loading from and training on the Tepper dataset. This can only be used for pedestrian forecasting in the same way as ETH/UCY; there are no pose annotations. In a future PR, we will clean up the dataloader infrastructure and support running experiments on other datasets in a cleaner fashion.

    opened by rccchoudhury 3
  • pred_epoch

    Hello @Khrylx , thank you for your great work.

    I didn't understand what pred_epoch means in the cfg files for each dataset. Why is it different from one dataset to another?

    Thanks

    opened by MZ82020 3
  • Questions about Visualization

    Hi @Khrylx ,

    Thank you for your work, it is really impressive.

    I am wondering if you have the visualisation script for these visualisations that you made. [two screenshots]

    opened by jjbecomespheh 2
  • data processing

    Hello, this is excellent work, but there is one thing I don't understand; I hope you can answer it. Why does the ETH/UCY data processing need to divide by a scale, as in "found_data = past_data[past_data[:, 1] == identity].squeeze()[[self.xind, self.zind]] / self.past_traj_scale"? Here self.past_traj_scale = 2, but in Trajectron++ the dataset is not divided by a scale.

    Looking forward to your reply.

    opened by JaneFo 2
  • Reg. distance between adjacent grid points in semantic maps for nuscenes

    Hi, congrats on the excellent work. I was going through your code and found that while converting the position of the agent in the image to pixel position you are multiplying it with a scale = 3. That should mean that the distance between adjacent pixels is 1/3 m as opposed to 3 m mentioned in the paper. Please let me know if I am interpreting anything wrong.

    opened by hrshl212 1
  • NAN value was obtained during training of the 1 sample trajectory sampler

    Hi

    I'm working on the nuScenes 1-sample training. After finishing 100 epochs of CVAE training, I continued to train the trajectory sampler. Unfortunately, I got a NaN value in the diversity loss term during training. [screenshot]

    I guess this is because the diversity loss term is divided by 0 when K=1, and it is meaningless to calculate the diversity loss when K=1. [screenshot]

    Maybe we need to modify this line: https://github.com/Khrylx/AgentFormer/blob/195aae0e466327ffff4e698e9747f30de33683e4/model/dlow.py#L24

    In case I missed something: have you ever encountered a NaN value when training the trajectory sampler with K=1?
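    A generic guard for this failure mode (a sketch with hypothetical names, not the repo's actual dlow.py code) would skip the pairwise term when fewer than two samples exist:

        import torch

        def diversity_loss(samples, scale=1.0):
            # samples: (K, T, 2) trajectories sampled for one agent
            K = samples.shape[0]
            if K < 2:
                # a single sample has no pairs; return 0 instead of a NaN mean over an empty set
                return samples.new_zeros(())
            diff = samples.unsqueeze(0) - samples.unsqueeze(1)  # (K, K, T, 2) pairwise differences
            dist = diff.pow(2).sum(dim=(-1, -2))                # (K, K) squared distances
            mask = ~torch.eye(K, dtype=torch.bool)              # drop the zero diagonal
            return torch.exp(-dist[mask] / scale).mean()        # mean over the K*(K-1) pairs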

    opened by ultimatedigiman 1
  • ADE/FDE Future mask Loss

    Hello, I have seen that you are using a mask on the MSE loss so as not to take the padded agents into consideration, which is good. However, why aren't you applying the same mask to the ADE and FDE metrics?

    opened by linaashaji 1
  • velocity and heading

    There are no velocity and heading values in the ETH and UCY datasets. If they are present in the nuScenes data, where can I find them? And why do we need the columns filled with -1.0? Thanks for the response!

    opened by prakash-bisht 1
  • NotImplementedError in agent_aware_attention

    Thanks for your great work! I am trying to apply your model to continual learning. Without modifying your network or dataloader, I encounter the following error. Since the error is raised inside one of your lib functions, and several forward passes have already completed, I find it very hard to debug. Could you please help me out?

    epoch:0,loss:mse: 20.561 (18.647) kld: 2.017 (2.326) sample: 20.089 (10.814) total_loss: 42.667 (31.787)
    epoch:0,loss:mse: 20.943 (37.872) kld: 2.045 (2.312) sample: 20.115 (24.269) total_loss: 43.103 (64.452)
    epoch:0,loss:mse: 21.206 (0.451) kld: 2.075 (2.000) sample: 20.142 (0.421) total_loss: 43.423 (2.872)
    epoch:0,loss:mse: 20.992 (4.669) kld: 2.075 (2.000) sample: 19.957 (4.490) total_loss: 43.024 (11.158)
    epoch:0,loss:mse: 21.082 (12.970) kld: 2.136 (6.585) sample: 20.000 (12.838) total_loss: 43.218 (32.393)
    epoch:0,loss:mse: 21.042 (16.777) kld: 2.331 (3.709) sample: 19.912 (16.551) total_loss: 43.286 (37.037)
    Traceback (most recent call last):
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./main.py", line 189, in <module>
        main()
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./main.py", line 162, in main
        task_iter(task, num_devices, pop, generation_id, loop_id, exp_config)
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/mu2net_traj/main.py", line 83, in task_iter
        train_loop(paths, ds_train, ds_validation,
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./mytrain.py", line 203, in train_loop
        model_data = path.model()
      File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 596, in forward
        self.inference(sample_num=self.loss_cfg['sample']['k'])
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 607, in inference
        self.future_decoder(self.data, mode=mode, sample_num=sample_num, autoregress=True, need_weights=need_weights)
      File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 425, in forward
        self.decode_traj_ar(data, mode, context, pre_motion, pre_vel, pre_motion_scene_norm, z, sample_num, need_weights=need_weights)
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 341, in decode_traj_ar
        tf_out, attn_weights = self.tf_decoder(tf_in_pos, context, memory_mask=mem_mask, tgt_mask=tgt_mask, num_agent=data['agent_num'], need_weights=need_weights)
      File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 746, in forward
        output, self_attn_weights[i], cross_attn_weights[i] = mod(output, memory, tgt_mask=tgt_mask,
      File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 644, in forward
        tgt2, self_attn_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
      File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 506, in forward
        return agent_aware_attention(
      File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 177, in agent_aware_attention
        raise NotImplementedError
    NotImplementedError
    
    opened by FUTUREEEEEE 0
  • Time encoder implementation

    Hi, I really like your work on multi-agent trajectory prediction. I went through the paper and code and have a quick question about the time encoder. As you mentioned in the paper, the time encoder that encodes the timestamp features differs from the original positional encoder, but I cannot find the time encoder code in this repo. Please let me know if I missed anything. Much appreciated! [screenshot]

    opened by Dennis-Tsai 0
  • list index out of range

    Hi. I have installed your package and configured the environment, but I still can't run it; an error is reported. The screenshots show the file format and the failing statement. [screenshots]

    opened by cxk-0425 1
  • Issues understanding the input format for eth_ucy

    Dear reader, thank you for the great work on this topic and for releasing the code for the community to improve. I am currently trying to understand the data format, but I am very unclear about how the eth_ucy dataset is actually preprocessed. Do I understand correctly that only x and y coordinates are used, and no velocity or heading information is extracted? Looking through the data, most columns only contain -1 values. Could you provide a list of column names for the inputs found in the datasets/eth_ucy files?

    Furthermore, is my understanding correct that only the ego agent's information is stored in a given row? From reading the paper, my understanding was that all agent states would be stored in a single entry for each timestep. Could you maybe elaborate on how you create the image representation of the data that is described in the paper? Thank you!

    opened by to314as 0
  • Most likely prediction for Nuscenes official prediction challenge

    Hi,

    Thanks for sharing this great work!

    I have a question about this paper: have you tried getting the most likely prediction on the nuScenes official prediction challenge and calculating the ADE?

    Bo

    opened by BoLang615 1