DCVNet: Dilated Cost Volumes for Fast Optical Flow
This repository contains our implementation of the paper:
@InProceedings{jiang2021dcvnet,
title={DCVNet: Dilated Cost Volumes for Fast Optical Flow},
author={Jiang, Huaizu and Learned-Miller, Erik},
booktitle={arXiv},
year={2021}
}
Need a fast optical flow model? Try DCVNet
- Fast. On a mid-range GTX 1080 Ti GPU, DCVNet runs in real time at 71 fps (frames per second) on images of size 1024 × 436.
- Compact and accurate. DCVNet has 4.94M parameters and consumes 1.68 GB of GPU memory during inference. It achieves accuracy comparable to state-of-the-art approaches on the MPI Sintel benchmark.
Requirements
This code has been tested with Python 3.7, PyTorch 1.6.0, and CUDA 9.2. We suggest using a conda environment:
conda create -n dcvnet python=3.7
conda activate dcvnet
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboardX scipy opencv -c pytorch
pip install yacs
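To confirm that the packages installed above can actually be imported, you can run a small standard-library check like the sketch below. It is not part of the repository, and the mapping from package names to import names is our own assumption (for example, opencv is imported as cv2).

```python
import importlib.util

# Package name -> module name used in `import`. Assumed from the install
# commands above; some packages import under a different name (opencv -> cv2).
REQUIRED_MODULES = {
    "pytorch": "torch",
    "torchvision": "torchvision",
    "matplotlib": "matplotlib",
    "tensorboardX": "tensorboardX",
    "scipy": "scipy",
    "opencv": "cv2",
    "yacs": "yacs",
}


def missing_packages(required=REQUIRED_MODULES):
    """Return the package names whose module cannot be located (nothing is imported)."""
    return [pkg for pkg, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```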
We use an open-source implementation, Pytorch-Correlation-extension (https://github.com/ClementPinard/Pytorch-Correlation-extension), to compute the dilated cost volumes. Follow the instructions there to install this module.
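The NumPy sketch below shows the general idea of a dilated cost volume. For each dilation rate d, the feature at pixel (y, x) in the first frame is correlated with the second-frame features at displacements d·(dy, dx) within a small radius, and the results for all rates are concatenated along the channel axis. The radius, the dilation rates, the zero padding, and the 1/sqrt(C) scaling are illustrative assumptions rather than DCVNet's actual settings or operator. The repository itself relies on the CUDA extension above for this computation.

```python
import numpy as np


def dilated_cost_volume(f1, f2, radius=4, dilations=(1, 2, 4)):
    """Correlate f1 with displaced copies of f2 at several dilation rates.

    f1, f2: feature maps of shape (C, H, W).
    Returns shape (len(dilations) * (2*radius+1)**2, H, W). Channel order is
    dilation, then dy, then dx; each channel holds the dot product of
    f1[:, y, x] and f2[:, y + d*dy, x + d*dx] divided by sqrt(C).
    Displacements that fall outside the image score 0 (zero padding).
    """
    c, h, w = f1.shape
    volumes = []
    for d in dilations:  # one pass per dilation rate
        pad = d * radius
        # Zero-pad f2 so every displaced position can be sliced directly.
        f2_pad = np.pad(f2, ((0, 0), (pad, pad), (pad, pad)))
        for dy in range(-radius, radius + 1):
            for dx in range(-radius, radius + 1):
                oy, ox = pad + d * dy, pad + d * dx
                shifted = f2_pad[:, oy:oy + h, ox:ox + w]  # f2 moved by (d*dy, d*dx)
                volumes.append((f1 * shifted).sum(axis=0) / np.sqrt(c))
    return np.stack(volumes, axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    f1 = rng.standard_normal((8, 16, 20))
    f2 = rng.standard_normal((8, 16, 20))
    print(dilated_cost_volume(f1, f2, radius=2, dilations=(1, 2)).shape)
```

With radius 2 and two dilation rates, the example yields 2 × 25 = 50 channels. Larger dilations let a small search window cover large displacements without increasing the number of channels per rate.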
Demos
Pretrained models can be downloaded by running
./scripts/download_models.sh
or from Google Drive.
You can run a pre-trained model on a sequence of frames:
python demo.py --weights-path pretrained_models/sceneflow_dcvnet.pth --path demo-frames
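We assume, without having checked demo.py, that it works like the RAFT demo: it sorts the image files in the folder and estimates flow for each pair of consecutive frames. Under that assumption, the hypothetical helpers below show which pairs would be processed.

```python
import os


def consecutive_frame_pairs(paths, extensions=(".png", ".jpg")):
    """Keep image files, sort them by name, and pair each frame with the next one."""
    frames = sorted(p for p in paths if p.lower().endswith(extensions))
    return list(zip(frames[:-1], frames[1:]))


def frame_pairs_in_folder(folder):
    """Apply consecutive_frame_pairs to every file in a folder such as demo-frames."""
    paths = [os.path.join(folder, name) for name in os.listdir(folder)]
    return consecutive_frame_pairs(paths)
```

Because the sorting is lexicographic, zero-padded frame numbers (frame_0009.png, frame_0010.png) keep the frames in temporal order.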
Required data
The following datasets are required to train and evaluate DCVNet.
- FlyingThings3D_subset (SceneFlow)
- Monkaa (SceneFlow)
- Driving (SceneFlow)
- MPI_Sintel
- KITTI2012
- KITTI2015
- HD1K (optional)
We borrow the data loaders from RAFT. By default, dcvnet/data/raft/datasets.py searches for the datasets in the locations below. You can create symbolic links in the datasets folder that point to wherever the datasets were downloaded.
|-- datasets
    |-- Driving
        |-- frames_cleanpass
        |-- optical_flow
    |-- FlyingThings3D_subset
        |-- train
            |-- flow
            |-- image_clean
        |-- val
            |-- flow
            |-- image_clean
    |-- Monkaa
        |-- frames_cleanpass
        |-- optical_flow
    |-- MPI_Sintel
        |-- test
        |-- training
    |-- KITTI2012
        |-- testing
        |-- training
    |-- KITTI2015
        |-- testing
        |-- training
    |-- HD1K
        |-- hd1k_flow_gt
        |-- hd1k_input
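The standard-library sketch below creates the symbolic links described above and lists any required sub-directories that are still missing. The expected layout is copied from the tree. HD1K is skipped because the dataset list marks it as optional. The helper names are our own and are not part of the repository.

```python
import os

# Expected sub-directories, transcribed from the tree above.
EXPECTED_LAYOUT = {
    "Driving": ["frames_cleanpass", "optical_flow"],
    "FlyingThings3D_subset": ["train/flow", "train/image_clean",
                              "val/flow", "val/image_clean"],
    "Monkaa": ["frames_cleanpass", "optical_flow"],
    "MPI_Sintel": ["test", "training"],
    "KITTI2012": ["testing", "training"],
    "KITTI2015": ["testing", "training"],
    "HD1K": ["hd1k_flow_gt", "hd1k_input"],
}
OPTIONAL = {"HD1K"}


def link_dataset(src, root, name):
    """Create root/name as a symbolic link to src unless something already exists there."""
    os.makedirs(root, exist_ok=True)
    dst = os.path.join(root, name)
    if not os.path.lexists(dst):
        os.symlink(os.path.abspath(src), dst, target_is_directory=True)
    return dst


def missing_dirs(root="datasets", layout=EXPECTED_LAYOUT, optional=OPTIONAL):
    """Return required directories that do not exist under root."""
    missing = []
    for name, subdirs in layout.items():
        if name in optional:
            continue
        for sub in subdirs:
            path = os.path.join(root, name, sub)
            if not os.path.isdir(path):
                missing.append(path)
    return missing
```

For example, link_dataset("/data/Sintel", "datasets", "MPI_Sintel") followed by print(missing_dirs()) links a hypothetical download location and then reports what is still absent.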
Evaluation
You can evaluate a pre-trained model using tools/evaluate_optical_flow.py:
python tools/evaluate_optical_flow.py --weights_path pretrained_models/sceneflow_dcvnet.pth --dataset sintel
You can optionally add the --amp switch to run inference in mixed precision, which reduces GPU memory usage.
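Sintel results are conventionally reported as average end-point error (EPE), and KITTI 2015 additionally uses the Fl outlier rate. The NumPy sketch below shows how these standard metrics are computed. We assume, but have not verified, that tools/evaluate_optical_flow.py computes them the same way.

```python
import numpy as np


def end_point_error(flow_pred, flow_gt, valid=None):
    """Mean Euclidean distance between predicted and ground-truth flow vectors.

    Flows have shape (H, W, 2). valid is an optional boolean (H, W) mask,
    needed for sparse ground truth such as KITTI.
    """
    epe = np.linalg.norm(flow_pred - flow_gt, axis=-1)
    if valid is not None:
        epe = epe[valid]
    return float(epe.mean())


def kitti_outlier_rate(flow_pred, flow_gt, valid):
    """Fraction of valid pixels whose error exceeds both 3 px and 5% of the
    ground-truth flow magnitude (the KITTI 2015 Fl criterion)."""
    epe = np.linalg.norm(flow_pred - flow_gt, axis=-1)
    mag = np.linalg.norm(flow_gt, axis=-1)
    outlier = (epe > 3.0) & (epe / np.maximum(mag, 1e-8) > 0.05)
    return float(outlier[valid].mean())
```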
Training
We used 8 GTX 1080 Ti GPUs for training. Training logs are written to the output folder and can be visualized with TensorBoard.
# train on the synthetic scene flow dataset
python tools/train_optical_flow.py --config-file configs/sceneflow_dcvnet.yaml
# fine-tune it on the MPI-Sintel dataset
# 4 GPUs are sufficient, but here we use 8 GPUs for faster training
python tools/train_optical_flow.py --config-file configs/sintel_dcvnet.yaml --pretrain-weights output/SceneFlow/sceneflow_dcvnet/default/train_epoch_50.pth
# fine-tune it on the KITTI 2012 and 2015 datasets
# we only use 6 GPUs (3 GPUs are sufficient) since the batch size is 6
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python tools/train_optical_flow.py --config-file configs/kitti12+15_dcvnet.yaml --pretrain-weights output/Sintel+SceneFlow/sintel_dcvnet/default/train_epoch_5.pth
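The GPU counts in the KITTI comment follow from simple arithmetic if the global batch size in the config is split evenly across GPUs. That is our assumption about how the training script distributes data, not something stated above. The hypothetical helpers below check the split and build the matching CUDA_VISIBLE_DEVICES value.

```python
def per_gpu_batch_size(total_batch_size, num_gpus):
    """Split a global batch evenly across GPUs, as in data-parallel training."""
    if total_batch_size % num_gpus != 0:
        raise ValueError(
            f"batch size {total_batch_size} is not divisible by {num_gpus} GPUs")
    return total_batch_size // num_gpus


def visible_devices(num_gpus):
    """Build a CUDA_VISIBLE_DEVICES value that exposes the first num_gpus GPUs."""
    return ",".join(str(i) for i in range(num_gpus))


if __name__ == "__main__":
    for n in (6, 3, 4):
        try:
            print(n, "GPUs ->", per_gpu_batch_size(6, n), "samples per GPU")
        except ValueError as err:
            print(n, "GPUs ->", err)
```

Under this assumption, a batch of 6 gives 1 sample per GPU on 6 GPUs and 2 on 3 GPUs, while 4 or 8 GPUs cannot split it evenly.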
Note on the inference speed
In the main branch, the dilated cost volumes are computed with a for loop, which can be avoided to speed up inference. Check out the efficient branch for details. If you want to test the inference speed, we suggest switching to the efficient branch:
git checkout efficient
CUDA_VISIBLE_DEVICES=0 python tools/evaluate_optical_flow.py --dry-run
We have not fixed this in the main branch yet. Our pre-trained models are based on its implementation and are not compatible with the resizing in the efficient branch, so all models need to be re-trained. This will be fixed soon.
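To time a model yourself, a generic harness like the sketch below avoids two common pitfalls. It excludes warm-up iterations, and it stops the clock only after asynchronous GPU work has finished. This is not the code behind --dry-run. run_once and sync are placeholders: bind run_once to your own model call and sync to torch.cuda.synchronize when timing on a GPU.

```python
import time


def measure_fps(run_once, warmup=10, iters=100, sync=None):
    """Return frames per second for a callable that processes one frame pair.

    sync: optional callable that blocks until queued GPU work finishes.
    Without it, asynchronous CUDA kernels make the timing look too optimistic.
    """
    for _ in range(warmup):  # exclude one-time costs such as cuDNN autotuning
        run_once()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return iters / max(elapsed, 1e-12)  # guard against a zero reading from a coarse clock
```

For reference, 71 fps corresponds to about 1000 / 71 ≈ 14.1 ms per 1024 × 436 frame pair.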
To-do
- Fix the problem of efficient cost volume computation.
- Train the model on the AutoFlow dataset.
Acknowledgment
Our implementation is built on top of RAFT, Pytorch-Correlation-extension, yacs, Detectron2, and semseg. We thank the authors for releasing and maintaining the code.