[ICCV21] Official implementation of the "Social NCE: Contrastive Learning of Socially-aware Motion Representations" in PyTorch.

Overview

Social-NCE + CrowdNav

Website | Paper | Video | Social NCE + Trajectron | Social NCE + STGCNN

This is an official implementation for
Social NCE: Contrastive Learning of Socially-aware Motion Representations
Yuejiang Liu, Qi Yan, Alexandre Alahi, ICCV 2021

TL;DR: Contrastive Representation Learning + Negative Data Augmentations 🡲 Robust Neural Motion Models

Preparation

Set up the environment following SETUP.md.

Training & Evaluation

  • Behavioral Cloning (Vanilla)
    python imitate.py --contrast_weight=0.0 --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-baseline-data-0.50/policy_net.pth
    
  • Social-NCE + Conventional Negative Sampling (Local)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='local' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-local-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0-range-2.00/policy_net.pth
    
  • Social-NCE + Safety-driven Negative Sampling (Ours)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='event' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth
    
  • Method Comparison
    bash script/run_vanilla.sh && bash script/run_local.sh && bash script/run_snce.sh
    python utils/compare.py
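The `--model_file` paths above appear to encode the training flags in the output folder name. If you change a flag and need to find the matching checkpoint, a minimal sketch of a hypothetical helper follows (`output_dir` and `model_file` are not part of the repository). The field order and number formats are inferred only from the three example paths; the meaning of the `data-0.50` field is not stated here, so it is simply passed through as `data`.

```python
def output_dir(sampling: str, data: float = 0.50, weight: float = 0.0,
               horizon: int = 4, temperature: float = 0.20,
               nboundary: int = 0, sample_range: float = 2.00) -> str:
    """Rebuild an output folder name in the pattern of the example paths.

    Hypothetical helper: the layout is inferred from the README paths,
    not read from imitate.py. Pass sampling="baseline" for the vanilla run.
    """
    name = f"imitate-{sampling}-data-{data:.2f}"
    if weight > 0:
        # contrastive runs append their contrastive hyperparameters
        name += (f"-weight-{weight:.1f}-horizon-{horizon}"
                 f"-temperature-{temperature:.2f}-nboundary-{nboundary}")
        if sampling == "local":
            # only the local-sampling example path carries a range field
            name += f"-range-{sample_range:.2f}"
    return name


def model_file(name: str) -> str:
    """Checkpoint path passed to test.py via --model_file."""
    return f"data/output/{name}/policy_net.pth"
```

For example, `model_file(output_dir("event", weight=2.0))` reproduces the path used in the safety-driven sampling command above.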
    

Basic Results

Results of behavioral cloning with different methods.

Averaged results from the 150th to 200th epochs.

Method    Collision       Reward
Vanilla   12.7% ± 3.8%    0.274 ± 0.019
Local     19.3% ± 4.2%    0.240 ± 0.021
Ours       2.0% ± 0.6%    0.331 ± 0.003
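Each entry above is a mean ± spread over the evaluations between epochs 150 and 200. A minimal sketch of that aggregation, assuming one scalar evaluation per epoch and using the population standard deviation; whether the reported numbers use population or sample standard deviation, and which epochs were evaluated, is not stated here. `summarize`, `format_percent`, and `format_reward` are hypothetical names.

```python
from statistics import mean, pstdev


def summarize(per_epoch, start=150, end=200):
    """Return (mean, population std) of the metric for epochs start..end inclusive.

    per_epoch maps an epoch number to a scalar metric, e.g. a collision rate.
    """
    window = [value for epoch, value in per_epoch.items() if start <= epoch <= end]
    if not window:
        raise ValueError(f"no evaluations between epochs {start} and {end}")
    return mean(window), pstdev(window)


def format_percent(m, s):
    # collision rate as shown in the table: percent with one decimal
    return f"{100 * m:.1f}% ± {100 * s:.1f}%"


def format_reward(m, s):
    # reward as shown in the table: three decimals
    return f"{m:.3f} ± {s:.3f}"
```

Epochs outside the window are ignored, so a full training log can be passed in directly.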

Citation

If you find this code useful for your research, please cite our papers:

@article{liu2020snce,
  title   = {Social NCE: Contrastive Learning of Socially-aware Motion Representations},
  author  = {Yuejiang Liu and Qi Yan and Alexandre Alahi},
  journal = {arXiv preprint arXiv:2012.11717},
  year    = {2020}
}
@inproceedings{chen2019crowdnav,
  title     = {Crowd-Robot Interaction: Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning},
  author    = {Changan Chen and Yuejiang Liu and Sven Kreiss and Alexandre Alahi},
  booktitle = {ICRA},
  year      = {2019}
}
Comments
  • What does SAIL refer to?

    Thanks for your wonderful repo! May I ask what the SAIL policy refers to? In the paper I can only find SARL, not SAIL. Is SAIL an upgraded version of SARL? Thanks!

    opened by chenzhutian 9
  • Question regarding training

    Hello,

    First of all, thank you for making the code publicly available! I have two questions about the training:

    1. As far as I understand, you determine the positive and negative samples based on the ground truth at time t, with a horizon of e.g. 5 time steps. What about possible collisions before and after that point?
    2. What is the ground truth in reinforcement learning? Do you use a linear model here?

    Best regards

    opened by Mirorrn 4
  • some questions about training error

    Hello, thanks for making the code available. I get the following error when I run "python imitate.py --contrast_weight=2.0 --contrast_sampling='local' --gpu" in the crowd_nav directory:

    RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)

    It seems that the linear operation in the attention module receives tensors of the wrong shape. What should I do? Thanks :)

    opened by wanglaotou 2
  • What does 'sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha))' mean

    Thanks for your hard work.

    I have a question about this code: for alpha in alpha_list: sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha)). It looks like α * neg_seed + (1 - α) * pos_seed. What is the purpose of this code?

    from sampling.py

        # primary-neighbor boundary
        if self.num_boundary > 0:
            alpha_list = torch.linspace(self.ratio_boundary, 1.0, steps=self.num_boundary) # 0.5 to 1 divided into 0 to 9 parts, determined by the parameters
            sample_boundary = []
            for alpha in alpha_list:
                sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha))
            sample_boundary = torch.cat(sample_boundary, axis=2)
            sample_neg = torch.cat([sample_boundary, sample_territory], axis=2)
        else:
            sample_neg = sample_territory
    

    Thanks!

    opened by jsor2009 2
  • Training on Trajnet++

    Hi, I am about to try your code on TrajNet++.

    I am trying to reproduce the FDE score of 1.14.

    Did you train on the whole training dataset (including cff)? For how many epochs, and with which parameters?

    If I reach your performance, I will delete the submission if desired.

    Thanks in advance. Many greetings

    opened by Mirorrn 2
  • Error when download data

    When I download the data with "pip install gdown && gdown https://drive.google.com/uc?id=1D2guAxD_EgrKnJFMcLSBkf10SOagz0mr", the following error occurs:

    Traceback (most recent call last):
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 169, in _new_conn
        conn = connection.create_connection(
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 96, in create_connection
        raise err
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 86, in create_connection
        sock.connect(sa)
    OSError: [Errno 101] Network is unreachable

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 699, in urlopen
        httplib_response = self._make_request(
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 382, in _make_request
        self._validate_conn(conn)
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 1010, in _validate_conn
        conn.connect()
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 353, in connect
        conn = self._new_conn()
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 181, in _new_conn
        raise NewConnectionError(
    urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 439, in send
        resp = conn.urlopen(
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 755, in urlopen
        retries = retries.increment(
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/retry.py", line 573, in increment
        raise MaxRetryError(_pool, url, error or ResponseError(cause))
    urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=1awXDsRQcmgacj7nUhPzwb5UMNZCJCjvu (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/lyl/miniconda3/bin/gdown", line 8, in <module>
        sys.exit(main())
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/gdown/cli.py", line 95, in main
        download(
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/gdown/download.py", line 77, in download
        res = sess.get(url, stream=True)
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 555, in get
        return self.request('GET', url, **kwargs)
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 542, in request
        resp = self.send(prep, **send_kwargs)
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 655, in send
        r = adapter.send(request, **kwargs)
      File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 516, in send
        raise ConnectionError(e, request=request)
    requests.exceptions.ConnectionError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=1awXDsRQcmgacj7nUhPzwb5UMNZCJCjvu (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

    Can you fix it? Thanks.

    opened by lylyjy 2
  • I cannot find "Imitator" in crowd_nav.utils.trainer

    Thank you for your hard work. In the file "crowd_nav/utils/demonstrate.py":

        from crowd_nav.utils.frames import FrameStack
        from crowd_nav.utils.trainer import Imitator
        from crowd_nav.utils.memory import ReplayMemory

    I cannot find "Imitator" in crowd_nav.utils.trainer. Could you help me?

    opened by yyf17 1
  • Some issues in test with square and mixed scenario

    Hello, my team and I are analyzing both of your works (CrowdNav and Social-NCE) for a term project. When I run test.py, I face two problems. First, in the square scenario, performance drops sharply. Why?

    !python test.py --policy='sail' --square --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth
    (skip)
    2021-12-05 08:55:59, INFO: TEST success: 0.58, collision: 0.41, nav time: 10.43, reward: 0.1095 +- 0.2644
    2021-12-05 08:55:59, INFO: Frequency of being in danger: 1.23

    Second, in the mixed scenario (I added the argument to the parser), the following error occurs:

    Traceback (most recent call last):
      File "test.py", line 129, in <module>
        main()
      File "test.py", line 126, in main
        explorer.run_k_episodes(env.case_size[args.phase], args.phase, print_failure=False)
      File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/utils/explorer.py", line 60, in run_k_episodes
        action = self.robot.act(ob)
      File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_sim/envs/utils/robot.py", line 13, in act
        action = self.policy.predict(state)
      File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/policy/sail.py", line 123, in predict
        action = self.model(self.last_state[0], self.last_state[1])[0].squeeze()
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/policy/sail.py", line 73, in forward
        human_state = self.transform.transform_frame(crowd_obsv)
      File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/utils/transform.py", line 14, in transform_frame
        state = torch.cat([frame, relative], axis=2)
    RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 5 for tensor number 1 in the list.

    If you already know about these problems, please share the answers. Thank you for reading.

    opened by LynAlpha 2
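On the `sample_boundary` question above: read purely from the formula, each term α * neg_seed + (1 - α) * pos_seed is a point on the straight line between the positive seed and a negative seed, so α = 1 reproduces the negative seed and smaller α moves the sample toward the positive. Under that reading, and consistent with the `# primary-neighbor boundary` comment, the loop appears to add extra negatives along the segment between the primary agent's positive location and each neighbor-based negative, with `unsqueeze(2)` broadcasting the single positive against the negative-sample axis that `torch.cat(..., axis=2)` later stacks. The scalar sketch below uses plain Python tuples instead of tensors; `linspace` and `boundary_samples` are hypothetical stand-ins, not functions from sampling.py.

```python
def linspace(start, stop, steps):
    """Evenly spaced values from start to stop inclusive (like torch.linspace)."""
    if steps == 1:
        return [start]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def boundary_samples(pos, neg, ratio_boundary, num_boundary):
    """Points alpha * neg + (1 - alpha) * pos for alpha in [ratio_boundary, 1].

    pos and neg are 2D points given as (x, y) tuples.
    """
    samples = []
    for alpha in linspace(ratio_boundary, 1.0, num_boundary):
        # alpha = 1 gives neg itself; smaller alpha moves toward pos
        samples.append(tuple(alpha * n + (1 - alpha) * p for n, p in zip(neg, pos)))
    return samples
```

With a positive at (0, 0), a negative at (2, 0), `ratio_boundary=0.5`, and `num_boundary=3`, the samples are (1.0, 0.0), (1.5, 0.0), and (2.0, 0.0): halfway, three quarters of the way, and at the negative seed.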
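For the TrajNet++ question above: FDE (final displacement error) is conventionally the Euclidean distance between the predicted and ground-truth positions at the last prediction step, averaged over trajectories. A minimal single-prediction sketch follows; it is not the TrajNet++ evaluator, whose exact protocol (for example for multimodal predictions) may differ, and `fde` / `mean_fde` are hypothetical names.

```python
import math


def fde(pred, gt):
    """Final displacement error of one trajectory: distance between last points.

    pred and gt are sequences of (x, y) positions over the prediction horizon.
    """
    (px, py), (gx, gy) = pred[-1], gt[-1]
    return math.hypot(px - gx, py - gy)


def mean_fde(preds, gts):
    """Average FDE over paired lists of predicted and ground-truth trajectories."""
    if len(preds) != len(gts) or not preds:
        raise ValueError("need the same, non-zero number of predictions and ground truths")
    return sum(fde(p, g) for p, g in zip(preds, gts)) / len(preds)
```

Only the last position of each trajectory matters here; errors at intermediate steps are what ADE (average displacement error) would capture instead.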
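On the mixed-scenario traceback above: `torch.cat` requires all inputs to have the same size in every dimension except the one being concatenated (ignoring PyTorch's special case for empty 1-D tensors), so the message means `frame` and `relative` in transform_frame disagree (1 vs 5) in some dimension other than 2. Which dimension that is, and what it represents, cannot be told from the log alone; printing `frame.shape` and `relative.shape` just before the call would show it. The sketch below is a hypothetical shape checker (`cat_shape` is not a PyTorch function) that applies the same rule to plain shape tuples.

```python
def cat_shape(shapes, dim):
    """Return the shape torch.cat would produce along dim, or raise ValueError.

    shapes is a list of shape tuples; dim is a non-negative dimension index.
    Rule: every dimension except dim must match across all inputs.
    """
    first = shapes[0]
    for index, shape in enumerate(shapes[1:], start=1):
        if len(shape) != len(first):
            raise ValueError(f"tensor number {index} has {len(shape)} dims, expected {len(first)}")
        for d, (expected, got) in enumerate(zip(first, shape)):
            if d != dim and expected != got:
                raise ValueError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {expected} but got size {got} for tensor number {index} in the list."
                )
    result = list(first)
    # the concatenated dimension is the sum of the inputs' sizes along dim
    result[dim] = sum(shape[dim] for shape in shapes)
    return tuple(result)
```

For instance, shapes (1, 5, 4) and (1, 5, 2) concatenate along dimension 2 to (1, 5, 6), while (1, 1, 4) and (1, 5, 2) fail with the same kind of message as in the traceback.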
Owner
VITA lab at EPFL
Visual Intelligence for Transportation