[ICCV'21] Official implementation for the paper Social NCE: Contrastive Learning of Socially-aware Motion Representations

Overview

CrowdNav with Social-NCE

This is an official implementation for the paper

Social NCE: Contrastive Learning of Socially-aware Motion Representations
by Yuejiang Liu, Qi Yan, Alexandre Alahi at EPFL
to appear at ICCV 2021

TL;DR: Contrastive Representation Learning + Negative Data Augmentations 🡲 Robust Neural Motion Models
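The contrastive half of the TL;DR can be illustrated with a generic InfoNCE loss. A query embedding should score higher against its one positive key than against every negative key. In Social NCE, for example, the query might be an embedding of the agent's observed history and the keys embeddings of candidate future locations. Here they are plain vectors.

This is a minimal, dependency-free sketch, not the repository's code. The following are our assumptions:

- cosine similarity as the scoring function;
- the hypothetical helper names `cosine`, `info_nce` and `total_loss`;
- the default temperature of 0.2, borrowed from the `temperature-0.20` tag in the model paths below.

The additive weighting mirrors the `--contrast_weight` flag. The README's vanilla baseline runs with `--contrast_weight=0.0`.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def info_nce(query, positive, negatives, temperature=0.2):
    """InfoNCE for one query: negative log of the positive key's softmax score.

    query, positive: embedding vectors (lists of floats).
    negatives: list of embedding vectors.
    """
    logits = [cosine(query, positive) / temperature]
    logits += [cosine(query, key) / temperature for key in negatives]
    # log-sum-exp with the max subtracted for numerical stability
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_denominator - logits[0]


def total_loss(imitation_loss, nce_loss, contrast_weight=2.0):
    """Task loss plus the weighted contrastive term (hypothetical helper)."""
    return imitation_loss + contrast_weight * nce_loss


# Usage: the query points roughly the same way as the positive key.
q = [1.0, 0.0]
pos = [0.9, 0.1]
negs = [[-1.0, 0.0], [0.0, 1.0]]
loss = total_loss(imitation_loss=0.8, nce_loss=info_nce(q, pos, negs), contrast_weight=2.0)
```

With `contrast_weight=0.0`, `total_loss` reduces to the imitation loss alone, which corresponds to the vanilla command. Larger weights give the contrastive term more influence on training.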

Please check out our code for experiments on different models as follows:
Social NCE + CrowdNav | Social NCE + Trajectron | Social NCE + STGCNN

Preparation

Set up the environment by following SETUP.md.

Training & Evaluation

  • Behavioral Cloning (Vanilla)
    python imitate.py --contrast_weight=0.0 --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-baseline-data-0.50/policy_net.pth
    
  • Social-NCE + Conventional Negative Sampling (Local)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='local' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-local-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0-range-2.00/policy_net.pth
    
  • Social-NCE + Safety-driven Negative Sampling (Ours)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='event' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth
    
  • Method Comparison
    bash script/run_vanilla.sh && bash script/run_local.sh && bash script/run_snce.sh
    python utils/compare.py
    
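The two `--contrast_sampling` modes above differ mainly in where the negative samples come from. The sketch below reflects our own reading of the two modes, not the repository's sampling code:

- **'local'** draws random points within a fixed range of the primary agent's future position. The range is 2.0 here, matching the `range-2.00` tag in the model path.
- **'event'** (safety-driven) places negatives on and around each neighbor's future position, in regions the primary agent should not occupy.

The function names, the 0.2 minimum separation and the eight-direction layout are all hypothetical choices made for illustration.

```python
import math
import random


def sample_negatives_event(neighbor_futures, min_separation=0.2, num_directions=8):
    """Safety-driven negatives (hypothetical sketch).

    For each neighbor's future (x, y) position, emit the position itself plus
    num_directions points on a circle of radius min_separation around it.
    """
    negatives = []
    for nx, ny in neighbor_futures:
        negatives.append((nx, ny))
        for k in range(num_directions):
            theta = 2.0 * math.pi * k / num_directions
            negatives.append((nx + min_separation * math.cos(theta),
                              ny + min_separation * math.sin(theta)))
    return negatives


def sample_negatives_local(primary_future, num_samples=16, max_range=2.0,
                           min_dist=0.2, rng=None):
    """Conventional local negatives (hypothetical sketch).

    Rejection-sample offsets whose length lies in [min_dist, max_range], so no
    negative sits on top of the primary's own future position.
    """
    if rng is None:
        rng = random.Random()
    px, py = primary_future
    negatives = []
    while len(negatives) < num_samples:
        dx = rng.uniform(-max_range, max_range)
        dy = rng.uniform(-max_range, max_range)
        if min_dist <= math.hypot(dx, dy) <= max_range:
            negatives.append((px + dx, py + dy))
    return negatives


# Usage: primary heading to (1, 0); two neighbors nearby.
event_negs = sample_negatives_event([(1.5, 0.5), (2.0, -1.0)])
local_negs = sample_negatives_local((1.0, 0.0), rng=random.Random(0))
print(len(event_negs), len(local_negs))  # 18 16
```

Under this reading, event-based negatives mark outcomes that are unsafe because they are close to other agents. Local negatives are only "not the ground truth" and may include perfectly safe positions.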

Basic Results

Results of behavioral cloning with different methods, averaged over epochs 150 to 200.

Method     Collision        Reward
Vanilla    12.7% ± 3.8%     0.274 ± 0.019
Local      19.3% ± 4.2%     0.240 ± 0.021
Ours        2.0% ± 0.6%     0.331 ± 0.003

Citation

If you find this code useful for your research, please cite our papers:

@article{liu2020snce,
  title     = {Social NCE: Contrastive Learning of Socially-aware Motion Representations},
  author    = {Yuejiang Liu and Qi Yan and Alexandre Alahi},
  journal   = {arXiv preprint arXiv:2012.11717},
  year      = {2020}
}

@inproceedings{chen2019crowdnav,
  title     = {Crowd-Robot Interaction: Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning},
  author    = {Changan Chen and Yuejiang Liu and Sven Kreiss and Alexandre Alahi},
  booktitle = {ICRA},
  year      = {2019}
}

Comments
  • What does SAIL refer to?

    Thanks for your wonderful repo! May I ask what the SAIL policy refers to? In the paper I can only find SARL, not SAIL. Is SAIL an upgraded version of SARL? Thanks!

    opened by chenzhutian 9
  • Question regarding training

    Hello,

    First of all, thank you for making the code publicly available! I have two questions to check my understanding of the training:

    1. As far as I understand, you determine the positive and negative samples from the ground truth at a single time t, with a horizon of, e.g., 5 time steps. What about possible collisions before and after that one point?
    2. What is the ground truth in reinforcement learning? Do you use a linear model here?

    Best regards

    opened by Mirorrn 4
  • some questions about training error

    Hello, thanks for making your code available. I get the following error when I run "python imitate.py --contrast_weight=2.0 --contrast_sampling='local' --gpu" in the crowd_nav directory:

        RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)

    It seems that the linear operation in the attention module receives tensors with the wrong shapes. What should I do? Thanks :)

    opened by wanglaotou 2
  • What does 'sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha))' mean?

    Thanks for your hard work.

    I have a question about this code: for alpha in alpha_list: sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha)). It looks like α * neg_seed + (1 - α) * pos_seed. What is the purpose of this code?

    From sampling.py:

        # primary-neighbor boundary
        if self.num_boundary > 0:
            alpha_list = torch.linspace(self.ratio_boundary, 1.0, steps=self.num_boundary)  # splits 0.5-1 into 0-9 parts; determined by the parameters
            sample_boundary = []
            for alpha in alpha_list:
                sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha))
            sample_boundary = torch.cat(sample_boundary, axis=2)
            sample_neg = torch.cat([sample_boundary, sample_territory], axis=2)
        else:
            sample_neg = sample_territory
    

    Thanks!

    opened by jsor2009 2
  • Training on Trajnet++

    Hi, I am about to try your code on TrajNet++.

    I am trying to reproduce the FDE score of 1.14.

    Did you train on the whole training dataset (with cff)? For how many epochs, and with which parameters?

    If I reach your performance, I will delete the submission if desired.

    Thanks in advance, and many greetings

    opened by Mirorrn 2
  • Error when download data

    When I download the data with "pip install gdown && gdown https://drive.google.com/uc?id=1D2guAxD_EgrKnJFMcLSBkf10SOagz0mr", the following error occurs:

        Traceback (most recent call last):
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 169, in _new_conn
            conn = connection.create_connection(
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 96, in create_connection
            raise err
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 86, in create_connection
            sock.connect(sa)
        OSError: [Errno 101] Network is unreachable

        During handling of the above exception, another exception occurred:

        Traceback (most recent call last):
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 699, in urlopen
            httplib_response = self._make_request(
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 382, in _make_request
            self._validate_conn(conn)
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 1010, in _validate_conn
            conn.connect()
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 353, in connect
            conn = self._new_conn()
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 181, in _new_conn
            raise NewConnectionError(
        urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable

        During handling of the above exception, another exception occurred:

        Traceback (most recent call last):
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 439, in send
            resp = conn.urlopen(
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 755, in urlopen
            retries = retries.increment(
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/retry.py", line 573, in increment
            raise MaxRetryError(_pool, url, error or ResponseError(cause))
        urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=1awXDsRQcmgacj7nUhPzwb5UMNZCJCjvu (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

        During handling of the above exception, another exception occurred:

        Traceback (most recent call last):
          File "/home/lyl/miniconda3/bin/gdown", line 8, in <module>
            sys.exit(main())
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/gdown/cli.py", line 95, in main
            download(
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/gdown/download.py", line 77, in download
            res = sess.get(url, stream=True)
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 555, in get
            return self.request('GET', url, **kwargs)
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 542, in request
            resp = self.send(prep, **send_kwargs)
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 655, in send
            r = adapter.send(request, **kwargs)
          File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 516, in send
            raise ConnectionError(e, request=request)
        requests.exceptions.ConnectionError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=1awXDsRQcmgacj7nUhPzwb5UMNZCJCjvu (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

    Can you fix it? Thanks.

    opened by lylyjy 2
  • I cannot find "Imitator" in crowd_nav.utils.trainer

    Thank you for your hard work. The file "crowd_nav/utils/demonstrate.py" contains:

        from crowd_nav.utils.frames import FrameStack
        from crowd_nav.utils.trainer import Imitator
        from crowd_nav.utils.memory import ReplayMemory

    I cannot find "Imitator" in crowd_nav.utils.trainer. Could you help me?

    opened by yyf17 1
  • Some issues in test with square and mixed scenario

    Hello, my team and I are analyzing your work (both CrowdNav and Social-NCE) as a term project. When I run test.py, I face two problems. First, in the square scenario, performance drops sharply. Why?

        !python test.py --policy='sail' --square --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth
        (skip)
        2021-12-05 08:55:59, INFO: TEST success: 0.58, collision: 0.41, nav time: 10.43, reward: 0.1095 +- 0.2644
        2021-12-05 08:55:59, INFO: Frequency of being in danger: 1.23

    Second, in the mixed scenario (I added the argument to the parser), the following error occurs:

        Traceback (most recent call last):
          File "test.py", line 129, in <module>
            main()
          File "test.py", line 126, in main
            explorer.run_k_episodes(env.case_size[args.phase], args.phase, print_failure=False)
          File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/utils/explorer.py", line 60, in run_k_episodes
            action = self.robot.act(ob)
          File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_sim/envs/utils/robot.py", line 13, in act
            action = self.policy.predict(state)
          File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/policy/sail.py", line 123, in predict
            action = self.model(self.last_state[0], self.last_state[1])[0].squeeze()
          File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
            return forward_call(*input, **kwargs)
          File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/policy/sail.py", line 73, in forward
            human_state = self.transform.transform_frame(crowd_obsv)
          File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/utils/transform.py", line 14, in transform_frame
            state = torch.cat([frame, relative], axis=2)
        RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 5 for tensor number 1 in the list.

    If you already know about these problems, please share the answers. Thank you for reading.

    opened by LynAlpha 2
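The interpolation asked about in the sampling.py issue can be read geometrically. We assume, from the variable names only, that `pos_seed` is the primary agent's positive sample and `neg_seed` holds the neighbors' negative seeds.

For each `alpha` in `linspace(ratio_boundary, 1.0, num_boundary)`, the sample `alpha * neg_seed + (1 - alpha) * pos_seed` lies on the segment between the two points:

- `alpha = 1.0` reproduces the neighbor seed itself.
- Smaller values move the sample toward the primary's positive.

This adds negatives along the boundary between the two agents. The pure-Python helpers below are a hypothetical illustration of that arithmetic on 2-D points, not the repository's tensor code.

```python
def linspace(start, stop, steps):
    """Evenly spaced values from start to stop, inclusive (mirrors torch.linspace)."""
    if steps == 1:
        return [float(start)]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def boundary_negatives(pos_seed, neg_seeds, ratio_boundary=0.5, num_boundary=3):
    """Points on the segments from the primary's positive to each neighbor seed.

    The outer loop over alpha mirrors the loop in sampling.py; for each alpha,
    one interpolated sample is produced per neighbor seed.
    """
    px, py = pos_seed
    samples = []
    for alpha in linspace(ratio_boundary, 1.0, num_boundary):
        for nx, ny in neg_seeds:
            samples.append((alpha * nx + (1 - alpha) * px,
                            alpha * ny + (1 - alpha) * py))
    return samples


print(boundary_negatives((0.0, 0.0), [(2.0, 0.0)]))
# [(1.0, 0.0), (1.5, 0.0), (2.0, 0.0)]
```

With `ratio_boundary=0.5`, the closest boundary negative sits halfway between the primary's positive and the neighbor. Raising `ratio_boundary` keeps all boundary negatives nearer to the neighbor.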
Owner
VITA lab at EPFL
Visual Intelligence for Transportation