(NeurIPS 2020) Wasserstein Distances for Stereo Disparity Estimation

Overview

Wasserstein Distances for Stereo Disparity Estimation

Accepted in NeurIPS 2020 as Spotlight. [Project Page]

by Divyansh Garg, Yan Wang, Bharath Hariharan, Mark Campbell, Kilian Q. Weinberger and Wei-Lun Chao

Citation

@inproceedings{div2020wstereo,
  title={Wasserstein Distances for Stereo Disparity Estimation},
  author={Garg, Divyansh and Wang, Yan and Hariharan, Bharath and Campbell, Mark and Weinberger, Kilian and Chao, Wei-Lun},
  booktitle={NeurIPS},
  year={2020}
}

Introduction

Existing approaches to depth or disparity estimation output a distribution over a set of pre-defined discrete values. This leads to inaccurate results when the true depth or disparity does not match any of these values. The fact that this distribution is usually learned indirectly through a regression loss causes further problems in ambiguous regions around object boundaries. We address these issues using a new neural network architecture that is capable of outputting arbitrary depth values, and a new loss function that is derived from the Wasserstein distance between the true and the predicted distributions. We validate our approach on a variety of tasks, including stereo disparity and depth estimation, and downstream 3D object detection. Our approach drastically reduces the error in ambiguous regions, especially around object boundaries that greatly affect the localization of objects in 3D, achieving state-of-the-art 3D object detection performance for autonomous driving.

Contents

Our Wasserstein loss modification (W_loss) can easily be plugged into existing stereo depth models to improve training and obtain better results.
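
For intuition, here is a minimal PyTorch sketch of such a plug-in loss, assuming the network outputs softmax weights over a fixed disparity grid plus a continuous offset per bin (the function name and signature are illustrative, not the repository's actual API):

import torch

def wasserstein_loss(prob, offsets, disp_grid, gt_disp, p=1.0):
    # prob:      (B, K, H, W) softmax weights over K disparity bins
    # offsets:   (B, K, H, W) continuous per-bin offsets
    # disp_grid: (K,)         pre-defined discrete disparity values
    # gt_disp:   (B, H, W)    ground-truth disparity
    # Each bin predicts a continuous disparity: grid value plus offset.
    pred = disp_grid.view(1, -1, 1, 1) + offsets
    # Against a Dirac target, the p-Wasserstein distance (raised to the
    # p-th power) is the expected |error|^p under the predicted weights.
    err = (pred - gt_disp.unsqueeze(1)).abs() ** p
    return (prob * err).sum(dim=1).mean()

With a single ground-truth value per pixel this closed form suffices; matching multi-modal ground truth (Eq. 12 in the paper) instead compares two full distributions.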

We release the code for CDN-PSMNet and CDN-SDN models.

Requirements

  1. Python 3.7
  2. PyTorch 1.2.0+
  3. CUDA
  4. Python packages: pip install -r ./requirements.txt
  5. SceneFlow dataset (see Datasets below)
  6. KITTI dataset (see Datasets below)

Pretrained Models

TO BE ADDED.

Datasets

You need to download the SceneFlow and KITTI datasets. The expected directory structures are shown below.

SceneFlow Dataset Structure

SceneFlow
    | monkaa
        | frames_cleanpass
        | disparity
    | driving
        | frames_cleanpass
        | disparity
    | flyingthings3d
        | frames_cleanpass 
        | disparity

KITTI Object Detection Dataset Structure

KITTI
    | training
        | calib
        | image_2
        | image_3
        | velodyne
    | testing
        | calib
        | image_2
        | image_3

Generate soft-links for the SceneFlow datasets. The links will be saved in the ./sceneflow folder. Please change the fakepath path-to-SceneFlow to the actual SceneFlow dataset location before running the script.

python sceneflow.py --path path-to-SceneFlow --force

Convert the KITTI velodyne ground truth to depth maps. Please change the fakepath path-to-KITTI to the actual KITTI dataset location before running the script.

python ./src/preprocess/generate_depth_map.py --data_path path-to-KITTI/ --split_file ./split/trainval.txt
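
For reference, this conversion projects each LiDAR point into the left camera image and keeps the nearest point per pixel. Below is a minimal NumPy sketch of the standard KITTI projection (illustrative only; generate_depth_map.py is the authoritative implementation, and P2, R0_rect, and Tr_velo_to_cam are the usual 3x4, 3x3, and 3x4 KITTI calibration matrices):

import numpy as np

def velodyne_to_depth_map(points, P2, R0_rect, Tr_velo_to_cam, h, w):
    # points: (N, 4) velodyne scan (x, y, z, reflectance).
    pts = np.hstack([points[:, :3], np.ones((points.shape[0], 1))])
    # velodyne -> rectified camera coordinates.
    cam = (R0_rect @ Tr_velo_to_cam @ pts.T).T
    cam = cam[cam[:, 2] > 0]  # keep points in front of the camera
    # rectified camera -> image plane.
    uvz = (P2 @ np.hstack([cam, np.ones((len(cam), 1))]).T).T
    u = np.round(uvz[:, 0] / uvz[:, 2]).astype(int)
    v = np.round(uvz[:, 1] / uvz[:, 2]).astype(int)
    depth = np.zeros((h, w), dtype=np.float32)
    valid = (u >= 0) & (u < w) & (v >= 0) & (v < h)
    for ui, vi, zi in zip(u[valid], v[valid], cam[valid, 2]):
        if depth[vi, ui] == 0 or zi < depth[vi, ui]:
            depth[vi, ui] = zi  # keep the nearest point per pixel
    return depth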

Optionally, download the KITTI2015 dataset for evaluating stereo disparity models.

Training and Inference

All pretrained models are provided under Pretrained Models above. If you only want to generate predictions, you can go directly to step 3.

The default setting requires four GPUs to train. If you don't have enough GPUs, you can reduce the training and validation batch sizes, btrain and bval.
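
For example, a hypothetical config excerpt halving the logged defaults (the key names btrain and bval match the options printed by main_depth.py, but check the files in src/configs/ for the exact key-value syntax before editing):

btrain: 6
bval: 2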

We provide code for both stereo disparity and stereo depth models.

1 Train CDN-SDN from Scratch on SceneFlow Dataset

python ./src/main_depth.py -c src/configs/sceneflow_w1.config

The checkpoints are saved in ./results/stack_sceneflow_w1/.

To train a stereo disparity model, follow the same procedure, but use src/main_disp.py and change to a disparity config.
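
For instance (the config filename here is a placeholder; use whichever disparity config ships in src/configs/):

python ./src/main_disp.py -c src/configs/sceneflow_disp_w1.config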

2 Train CDN-SDN on KITTI Dataset

python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --pretrain ./results/sceneflow_w1/checkpoint.pth.tar --dataset  path-to-KITTI/training/

Before running, please change the fakepath path-to-KITTI/ to the correct one. --pretrain is the path to the pretrained model on SceneFlow. The training results are saved in ./results/kitti_w1_train.

If you plan to evaluate CDN on the KITTI testing set, you might want to train CDN on the training + validation sets. The training results will be saved in ./results/sdn_kitti_trainval.

python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --pretrain ./results/sceneflow_w1/checkpoint.pth.tar \
    --dataset  path-to-KITTI/training/ --split_train ./split/trainval.txt \
    --save_path ./results/sdn_kitti_trainval

The disparity models can also be trained on the KITTI2015 dataset using src/kitti2015_w1_disp.config.
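
For example (an illustrative invocation; path-to-KITTI2015 is a fakepath, and the exact flags of main_disp.py may differ, so check its help output):

python ./src/main_disp.py -c src/kitti2015_w1_disp.config --datapath path-to-KITTI2015/training/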

3 Generate Predictions

Please change the fakepath path-to-KITTI to the actual dataset location. Moreover, if you use our provided checkpoint, set the value of --resume to the checkpoint location.

  • a. Use the model trained on the KITTI training set to generate predictions on the training + validation sets.
python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --resume ./results/sdn_kitti_train/checkpoint.pth.tar --datapath  path-to-KITTI/training/ \
    --data_list ./split/trainval.txt --generate_depth_map --data_tag trainval

The results will be saved in ./results/sdn_kitti_train/depth_maps_trainval/.

  • b. Use the model trained on the KITTI training + validation sets to generate predictions on the testing set. You will need these predictions when submitting results to the leaderboard.

# testing sets
python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --resume ./results/sdn_kitti_trainval/checkpoint.pth.tar --datapath  path-to-KITTI/testing/ \
    --data_list=./split/test.txt --generate_depth_map --data_tag test

The results will be saved in ./results/sdn_kitti_trainval/depth_maps_test/.

4 Train 3D Detection with Pseudo-LiDAR

To train 3D object detection models, follow step 4 and onward in the Pseudo-LiDAR_V2 repo: https://github.com/mileyan/Pseudo_Lidar_V2.
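
At the heart of that pipeline is back-projecting every predicted depth pixel into a 3D point using the camera intrinsics. A minimal NumPy sketch of this standard conversion (illustrative; fu, fv are the focal lengths and (cu, cv) the principal point from the calibration, and the linked repo contains the actual code):

import numpy as np

def depth_to_pseudo_lidar(depth, fu, fv, cu, cv):
    # Back-project a dense depth map into an (N, 3) pseudo-LiDAR point
    # cloud in the camera frame (illustrative sketch).
    h, w = depth.shape
    v, u = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    z = depth
    x = (u - cu) * z / fu  # right
    y = (v - cv) * z / fv  # down
    points = np.stack([x, y, z], axis=-1).reshape(-1, 3)
    return points[points[:, 2] > 0]  # drop pixels without valid depth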

Results

Stereo Disparity Results

3D Object Detection Results on the KITTI Leaderboard

Questions

Please feel free to email us if you have any questions.

Divyansh Garg: [email protected]
Yan Wang: [email protected]
Wei-Lun Chao: [email protected]

Comments
  • Question on pre-trained models

    Hi, thank you for releasing the pre-trained depth models. Could you please advise if there is a script to re-produce CDN-DSGN in the current repo? Or are there instructions to do so?

    opened by Nicholasli1995 3
  • Question about multi-modal GT

    Hi, thank you for sharing the code and I have a few questions:

    1. Where can we find the generation of the multi-modal ground truth (MM GT) described in the paper?
    2. Is there an implementation of Wasserstein Loss with MM GT as described by Eq. 12 in the paper?
    3. Did DSGN+CDN (the best 3D AP on KITTI) use MM GT or not? I'm using DSGN and trying to reproduce DSGN+CDN+Wasserstein Loss.
    opened by Nicholasli1995 1
  • Missing imports

    Hi, I am trying to run the training myself using the instructions provided but there appear to be missing files.

    After downloading the SceneFlow dataset, if I try to run

    python ./src/main_depth.py -c src/configs/sceneflow_w1.config
    

    I get the following error:

    Traceback (most recent call last):
      File "./src/main_depth.py", line 21, in <module>
        import models
      File "C:\Code\I3DR\W-Stereo-Disp\src\models\__init__.py", line 1, in <module>
        from .full_res import PSMNet as basic
    ModuleNotFoundError: No module named 'models.full_res'
    

    Looking in this file, it seems that __init__.py is actually missing a lot of the imports:

    from .full_res import PSMNet as basic
    from .stackhourglass import PSMNet as stackhourglass
    from .stackhourglass_classif import PSMNet as stackhourglass_classif
    from .stackhourglass_edge_aware import PSMNet as stackhourglass_edge_aware
    from .stackhourglass_full import PSMNet as stackhourglass_full
    from .stackhourglass_semantic import PSMNet as stackhourglass_semantic
    from .stackhourglass_softmax_offset import PSMNet as stackhourglass_softmax_offset
    from .stackhourglass_std import PSMNet as stackhourglass_std
    from .stackhourglass_volume import PSMNet as stackhourglass_volume
    from .stackhourglass_volume_large_off import PSMNet as stackhourglass_volume_large_off
    from .stackhourglass_volume_multihead import PSMNet as stackhourglass_multihead
    from .stackhourglass_volume_semantic import PSMNet as stackhourglass_volume_semantic
    from .stackhourglass_win import PSMNet as stackhourglass_win
    

    Most of these are missing. Should I have done something to generate these or have they been excluded?

    I tried commenting out the missing imports, and then I found that an API key is required for losswise. This isn't a service I have used before, but I registered for an account and added my personal API key, which let the script continue.

    This seems to work and I am currently running the training. Was this the correct procedure?

    opened by benknight135 1
  • about kitti dataset image_3

    about kitti dataset image_3

    How did you use image_3 as a training set in your training? There is no official label for image_3; only label_2, which corresponds to image_2, exists. If the image_3 labels are generated by a transformation, can you tell me how to do it?

    opened by JunjieChen-2020 0
  • Data loading problem

    When I run the code with Docker on the server, I encounter the following situation: the run gets stuck at this step and still has not proceeded to the next step after two days. Reconfiguring the environment and other methods had no effect. Can you help me?

    root@c860f179a9eb:~/smd/WDSDE/W-Stereo-Disp# python ./src/main_depth.py -c src/configs/kitti_w1.config --resume ./results/sdn_kitti_trainval/checkpoint.pth.tar --datapath ./KITTI/testing/ --data_list=./split/test.txt --generate_depth_map --data_tag test
    TPQAWUTNB
    [2021-08-02 05:34:23 main_depth.py:165] INFO api_key: TPQAWUTNB
    [2021-08-02 05:34:23 main_depth.py:165] INFO arch: stackhourglass_volume
    [2021-08-02 05:34:23 main_depth.py:165] INFO btrain: 12
    [2021-08-02 05:34:23 main_depth.py:165] INFO bval: 4
    [2021-08-02 05:34:23 main_depth.py:165] INFO calib_value: 1017
    [2021-08-02 05:34:23 main_depth.py:165] INFO checkpoint_interval: -1
    [2021-08-02 05:34:23 main_depth.py:165] INFO config: src/configs/kitti_w1.config
    [2021-08-02 05:34:23 main_depth.py:165] INFO data_list: ./split/test.txt
    [2021-08-02 05:34:23 main_depth.py:165] INFO data_tag: test
    [2021-08-02 05:34:23 main_depth.py:165] INFO data_type: depth
    [2021-08-02 05:34:23 main_depth.py:165] INFO datapath: ./KITTI/testing/
    [2021-08-02 05:34:23 main_depth.py:165] INFO dataset: kitti
    [2021-08-02 05:34:23 main_depth.py:165] INFO depth_wise_loss: False
    [2021-08-02 05:34:23 main_depth.py:165] INFO down: 2
    [2021-08-02 05:34:23 main_depth.py:165] INFO dynamic_bs: False
    [2021-08-02 05:34:23 main_depth.py:165] INFO epochs: 300
    [2021-08-02 05:34:23 main_depth.py:165] INFO eval_interval: 50
    [2021-08-02 05:34:23 main_depth.py:165] INFO evaluate: False
    [2021-08-02 05:34:23 main_depth.py:165] INFO generate_depth_map: True
    [2021-08-02 05:34:23 main_depth.py:165] INFO kitti2015: False
    [2021-08-02 05:34:23 main_depth.py:165] INFO losswise_tag: finetune_w1_fix
    [2021-08-02 05:34:23 main_depth.py:165] INFO lr: 0.001
    [2021-08-02 05:34:23 main_depth.py:165] INFO lr_gamma: 0.1
    [2021-08-02 05:34:23 main_depth.py:165] INFO lr_stepsize: [200]
    [2021-08-02 05:34:23 main_depth.py:165] INFO maxdepth: 80
    [2021-08-02 05:34:23 main_depth.py:165] INFO maxdisp: 192
    [2021-08-02 05:34:23 main_depth.py:165] INFO pretrain: ./results/checkpoint.pth.tar
    [2021-08-02 05:34:23 main_depth.py:165] INFO resume: ./results/sdn_kitti_trainval/checkpoint.pth.tar
    [2021-08-02 05:34:23 main_depth.py:165] INFO save_path: ./results/kitti_w1_train
    [2021-08-02 05:34:23 main_depth.py:165] INFO scale: 1
    [2021-08-02 05:34:23 main_depth.py:165] INFO split_train: ./split/train.txt
    [2021-08-02 05:34:23 main_depth.py:165] INFO split_val: ./split/subval.txt
    [2021-08-02 05:34:23 main_depth.py:165] INFO start_epoch: 0
    [2021-08-02 05:34:23 main_depth.py:165] INFO w_p: 1
    [2021-08-02 05:34:23 main_depth.py:165] INFO warmup_epochs: 0
    [2021-08-02 05:34:23 main_depth.py:209] INFO Number of model parameters: 5310496
    [2021-08-02 05:34:28 main_depth.py:219] INFO => loading pretrain './results/checkpoint.pth.tar'
    [2021-08-02 05:34:28 main_depth.py:227] INFO => loading checkpoint './results/sdn_kitti_trainval/checkpoint.pth.tar'
    [2021-08-02 05:34:29 main_depth.py:235] INFO => loaded checkpoint './results/sdn_kitti_trainval/checkpoint.pth.tar' (epoch 300)
    0%| | 0/1880 [00:00<?, ?it/s]

    opened by xhangHU 0