
Wasserstein Distances for Stereo Disparity Estimation

[Project Page] Accepted to NeurIPS 2020 as a Spotlight

Wasserstein Distances for Stereo Disparity Estimation

by Divyansh Garg, Yan Wang, Bharath Hariharan, Mark Campbell, Kilian Q. Weinberger and Wei-Lun Chao

[Figure]

Citation

@inproceedings{div2020wstereo,
  title={Wasserstein Distances for Stereo Disparity Estimation},
  author={Garg, Divyansh and Wang, Yan and Hariharan, Bharath and Campbell, Mark and Weinberger, Kilian and Chao, Wei-Lun},
  booktitle={NeurIPS},
  year={2020}
}

Introduction

Existing approaches to depth or disparity estimation output a distribution over a set of pre-defined discrete values. This leads to inaccurate results when the true depth or disparity does not match any of these values. The fact that this distribution is usually learned indirectly through a regression loss causes further problems in ambiguous regions around object boundaries. We address these issues using a new neural network architecture that is capable of outputting arbitrary depth values, and a new loss function that is derived from the Wasserstein distance between the true and the predicted distributions. We validate our approach on a variety of tasks, including stereo disparity and depth estimation, and the downstream 3D object detection. Our approach drastically reduces the error in ambiguous regions, especially around object boundaries that greatly affect the localization of objects in 3D, achieving the state-of-the-art in 3D object detection for autonomous driving.

Update

  • 1 June 2021: Released all pretrained models. Added script to download KITTI Object Detection dataset.

Contents

Our Wasserstein loss W_loss can be easily plugged into existing stereo depth models to improve training and obtain better results.
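
To make the idea concrete, here is a minimal PyTorch sketch of a p-Wasserstein loss against a point-mass ground truth (the function name, tensor shapes, and mean reduction are our assumptions, not the repo's actual API). Because all predicted probability mass must be transported to the single ground-truth value, W_p^p reduces to a probability-weighted sum of |d_i - d*|^p:

import torch

def w_loss(prob, disp_candidates, gt_disp, p=1):
    # prob:            (B, D, H, W) softmax weights over D disparity candidates
    # disp_candidates: (D,) candidate disparity values
    # gt_disp:         (B, H, W) ground-truth disparity
    diff = disp_candidates.view(1, -1, 1, 1) - gt_disp.unsqueeze(1)  # (B, D, H, W)
    return (prob * diff.abs().pow(p)).sum(dim=1).mean()             # mean of W_p^p

The full method in the paper additionally lets the network output arbitrary off-grid values and handles multi-modal ground truth (Eq. 12); this sketch covers only the basic single-mode case.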

We release the code for CDN-PSMNet and CDN-SDN models.

Requirements

  1. Python 3.7
  2. PyTorch 1.2.0+
  3. CUDA
  4. pip install -r ./requirements.txt
  5. SceneFlow
  6. KITTI
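
For reference, a possible environment setup (the environment name and exact package versions are assumptions; match the PyTorch build to your CUDA version):

conda create -n w-stereo-disp python=3.7
conda activate w-stereo-disp
pip install torch==1.2.0 torchvision==0.4.0
pip install -r ./requirements.txt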

Pretrained Models

Place the checkpoint folders in ./results.

Depth Models

Disparity Models

Datasets

You have to download the SceneFlow and KITTI datasets. The structures of the datasets are shown below.

KITTI can be automatically downloaded using ./scripts/download_kitti.sh

SceneFlow Dataset Structure

SceneFlow
    | monkaa
        | frames_cleanpass
        | disparity
    | driving
        | frames_cleanpass
        | disparity
    | flyingthings3d
        | frames_cleanpass 
        | disparity

KITTI Object Detection Dataset Structure

KITTI
    | training
        | calib
        | image_2
        | image_3
        | velodyne
    | testing
        | calib
        | image_2
        | image_3

Generate soft-links of the SceneFlow datasets. The results will be saved in the ./sceneflow folder. Please change the fakepath path-to-SceneFlow to the SceneFlow dataset location before running the script.

python sceneflow.py --path path-to-SceneFlow --force

Convert the KITTI velodyne ground truths to depth maps. Please change the fakepath path-to-KITTI to the KITTI dataset location before running the script.

python ./src/preprocess/generate_depth_map.py --data_path path-to-KITTI/ --split_file ./split/trainval.txt
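
The script's internals aren't shown here, but the underlying operation is the standard KITTI projection of LiDAR points into the image plane. Below is a rough, hypothetical sketch (the function name and signature are our own; proj is the usual combined matrix P2 @ R0_rect @ Tr_velo_to_cam built from the KITTI calibration files):

import numpy as np

def velodyne_to_depth_map(points, proj, h, w):
    # points: (N, 3) LiDAR x, y, z; proj: (3, 4) = P2 @ R0_rect @ Tr_velo_to_cam
    pts_h = np.hstack([points, np.ones((len(points), 1))])  # to homogeneous coords
    cam = proj @ pts_h.T                                    # (3, N) image-plane coords
    z = cam[2]
    keep = z > 0                                            # drop points behind the camera
    u = np.rint(cam[0, keep] / z[keep]).astype(int)
    v = np.rint(cam[1, keep] / z[keep]).astype(int)
    z = z[keep]
    inside = (u >= 0) & (u < w) & (v >= 0) & (v < h)
    u, v, z = u[inside], v[inside], z[inside]
    depth = np.zeros((h, w), dtype=np.float32)
    order = np.argsort(-z)                # write far points first so that the
    depth[v[order], u[order]] = z[order]  # nearest point survives at each pixel
    return depth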

Optionally, download the KITTI 2015 dataset for evaluating stereo disparity models.

Training and Inference

We provide all pretrained models (see Pretrained Models above). If you only want to generate predictions, you can skip directly to step 3.

We use config files to simplify argument parsing. The default setting requires four GPUs to train. If you don't have enough GPUs, you can reduce the batch sizes btrain and bval in the config files.
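
For illustration, these are the corresponding entries as echoed in the training log further down this page (the actual config-file syntax may differ; treat this as an excerpt of the defaults, not a template):

btrain: 12
bval: 4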

We provide code for both stereo disparity and stereo depth models.

We optionally use Losswise to visualize training metrics. To enable it, obtain an API key and set the api_key field in the config file.

1 Train CDN-SDN from Scratch on SceneFlow Dataset

python ./src/main_depth.py -c src/configs/sceneflow_w1.config

The checkpoints are saved in ./results/stack_sceneflow_w1/.

Follow the same procedure to train a stereo disparity model, but use src/main_disp.py and switch to a disparity config, as in the example below.
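
For example (this config filename is hypothetical; see src/configs/ for the actual disparity configs):

python ./src/main_disp.py -c src/configs/sceneflow_disp_w1.config  # hypothetical config name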

2 Train CDN-SDN on KITTI Dataset

python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --pretrain ./results/sceneflow_w1/checkpoint.pth.tar --datapath  path-to-KITTI/training/

Before running, please change the fakepath path-to-KITTI/ to the correct one. --pretrain is the path to the pretrained model on SceneFlow. The training results are saved in ./results/kitti_w1_train.

If you plan to evaluate CDN on the KITTI testing set, you might want to train CDN on the training + validation sets. The training results will be saved in ./results/sdn_kitti_trainval.

python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --pretrain ./results/sceneflow_w1/checkpoint.pth.tar \
    --datapath  path-to-KITTI/training/ --split_train ./split/trainval.txt \
    --save_path ./results/sdn_kitti_trainval

The disparity models can also be trained on the KITTI 2015 dataset using src/kitti2015_w1_disp.config.

3 Generate Predictions

Please change the fakepath path-to-KITTI. Moreover, if you use our provided checkpoints, please set --resume to the checkpoint location.

  • a. Using the model trained on the KITTI training set, generate predictions on the training + validation sets.
python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --resume ./results/sdn_kitti_train/checkpoint.pth.tar --datapath  path-to-KITTI/training/ \
    --data_list ./split/trainval.txt --generate_depth_map --data_tag trainval

The results will be saved in ./results/sdn_kitti_train/depth_maps_trainval/.

  • b. Using the model trained on the KITTI training + validation sets, generate predictions on the testing set. You will need these when submitting results to the leaderboard.

# testing sets
python ./src/main_depth.py -c src/configs/kitti_w1.config \
    --resume ./results/sdn_kitti_trainval/checkpoint.pth.tar --datapath  path-to-KITTI/testing/ \
    --data_list=./split/test.txt --generate_depth_map --data_tag test

The results will be saved in ./results/sdn_kitti_trainval/depth_maps_test/.

4 Train 3D Detection with Pseudo-LiDAR

For training 3D object detection models, follow step 4 and onward in the Pseudo-LiDAR_V2 repo: https://github.com/mileyan/Pseudo_Lidar_V2.
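
Those steps convert the predicted depth maps into pseudo-LiDAR point clouds by back-projecting every pixel through the camera intrinsics. A minimal sketch of that conversion, under our own naming (the real pipeline also transforms the points from the camera frame into the LiDAR frame using the calibration extrinsics):

import numpy as np

def depth_to_pseudo_lidar(depth, fx, fy, cx, cy):
    # depth: (H, W) predicted depth map; fx, fy, cx, cy: pinhole intrinsics
    h, w = depth.shape
    u, v = np.meshgrid(np.arange(w), np.arange(h))  # pixel grid
    x = (u - cx) * depth / fx                       # back-project each pixel
    y = (v - cy) * depth / fy
    return np.stack([x, y, depth], axis=-1).reshape(-1, 3)  # (H*W, 3) camera-frame points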

Results

Results on Stereo Disparity

[Figure]

3D Object Detection Results on the KITTI Leaderboard

[Figure]

Questions

Please feel free to email us if you have any questions.

Divyansh Garg ([email protected]), Yan Wang ([email protected]), Wei-Lun Chao ([email protected])


Issues

Missing imports

Hi,
I am trying to run the training myself using the instructions provided but there appear to be missing files.

After downloading the SceneFlow dataset, if I try to run

python ./src/main_depth.py -c src/configs/sceneflow_w1.config

I get the following error:

Traceback (most recent call last):
  File "./src/main_depth.py", line 21, in <module>
    import models
  File "C:\Code\I3DR\W-Stereo-Disp\src\models\__init__.py", line 1, in <module>
    from .full_res import PSMNet as basic
ModuleNotFoundError: No module named 'models.full_res'

Looking at this file, it seems that __init__.py references a lot of modules that are missing:

from .full_res import PSMNet as basic
from .stackhourglass import PSMNet as stackhourglass
from .stackhourglass_classif import PSMNet as stackhourglass_classif
from .stackhourglass_edge_aware import PSMNet as stackhourglass_edge_aware
from .stackhourglass_full import PSMNet as stackhourglass_full
from .stackhourglass_semantic import PSMNet as stackhourglass_semantic
from .stackhourglass_softmax_offset import PSMNet as stackhourglass_softmax_offset
from .stackhourglass_std import PSMNet as stackhourglass_std
from .stackhourglass_volume import PSMNet as stackhourglass_volume
from .stackhourglass_volume_large_off import PSMNet as stackhourglass_volume_large_off
from .stackhourglass_volume_multihead import PSMNet as stackhourglass_multihead
from .stackhourglass_volume_semantic import PSMNet as stackhourglass_volume_semantic
from .stackhourglass_win import PSMNet as stackhourglass_win

Most of these modules are missing. Should I have done something to generate them, or have they been excluded?

I tried commenting out the missing imports, and then found that an API key is required for Losswise. This isn't a service I have used before, but I registered for an account and added my personal API key, which let the script continue.

This seems to work and I am currently running the training. Was this the correct procedure?

Question about multi-modal GT

Hi, thank you for sharing the code and I have a few questions:

  1. Where can we find the generation of the multi-modal ground truth (MM GT) described in the paper?
  2. Is there an implementation of the Wasserstein loss with MM GT, as given by Eq. 12 in the paper?
  3. Did DSGN+CDN (the best 3D AP on KITTI) use MM GT or not? I'm using DSGN and trying to reproduce DSGN+CDN+Wasserstein loss.

Data loading problem

When I run this in Docker on a server, I hit the following situation: the process gets stuck at this step and has not moved on after two days. Reconfiguring the environment and other attempted fixes had no effect. Can you help me?

root@c860f179a9eb:~/smd/WDSDE/W-Stereo-Disp# python ./src/main_depth.py -c src/configs/kitti_w1.config --resume ./results/sdn_kitti_trainval/checkpoint.pth.tar --datapath ./KITTI/testing/ --data_list=./split/test.txt --generate_depth_map --data_tag test
TPQAWUTNB
[2021-08-02 05:34:23 main_depth.py:165] INFO api_key: TPQAWUTNB
[2021-08-02 05:34:23 main_depth.py:165] INFO arch: stackhourglass_volume
[2021-08-02 05:34:23 main_depth.py:165] INFO btrain: 12
[2021-08-02 05:34:23 main_depth.py:165] INFO bval: 4
[2021-08-02 05:34:23 main_depth.py:165] INFO calib_value: 1017
[2021-08-02 05:34:23 main_depth.py:165] INFO checkpoint_interval: -1
[2021-08-02 05:34:23 main_depth.py:165] INFO config: src/configs/kitti_w1.config
[2021-08-02 05:34:23 main_depth.py:165] INFO data_list: ./split/test.txt
[2021-08-02 05:34:23 main_depth.py:165] INFO data_tag: test
[2021-08-02 05:34:23 main_depth.py:165] INFO data_type: depth
[2021-08-02 05:34:23 main_depth.py:165] INFO datapath: ./KITTI/testing/
[2021-08-02 05:34:23 main_depth.py:165] INFO dataset: kitti
[2021-08-02 05:34:23 main_depth.py:165] INFO depth_wise_loss: False
[2021-08-02 05:34:23 main_depth.py:165] INFO down: 2
[2021-08-02 05:34:23 main_depth.py:165] INFO dynamic_bs: False
[2021-08-02 05:34:23 main_depth.py:165] INFO epochs: 300
[2021-08-02 05:34:23 main_depth.py:165] INFO eval_interval: 50
[2021-08-02 05:34:23 main_depth.py:165] INFO evaluate: False
[2021-08-02 05:34:23 main_depth.py:165] INFO generate_depth_map: True
[2021-08-02 05:34:23 main_depth.py:165] INFO kitti2015: False
[2021-08-02 05:34:23 main_depth.py:165] INFO losswise_tag: finetune_w1_fix
[2021-08-02 05:34:23 main_depth.py:165] INFO lr: 0.001
[2021-08-02 05:34:23 main_depth.py:165] INFO lr_gamma: 0.1
[2021-08-02 05:34:23 main_depth.py:165] INFO lr_stepsize: [200]
[2021-08-02 05:34:23 main_depth.py:165] INFO maxdepth: 80
[2021-08-02 05:34:23 main_depth.py:165] INFO maxdisp: 192
[2021-08-02 05:34:23 main_depth.py:165] INFO pretrain: ./results/checkpoint.pth.tar
[2021-08-02 05:34:23 main_depth.py:165] INFO resume: ./results/sdn_kitti_trainval/checkpoint.pth.tar
[2021-08-02 05:34:23 main_depth.py:165] INFO save_path: ./results/kitti_w1_train
[2021-08-02 05:34:23 main_depth.py:165] INFO scale: 1
[2021-08-02 05:34:23 main_depth.py:165] INFO split_train: ./split/train.txt
[2021-08-02 05:34:23 main_depth.py:165] INFO split_val: ./split/subval.txt
[2021-08-02 05:34:23 main_depth.py:165] INFO start_epoch: 0
[2021-08-02 05:34:23 main_depth.py:165] INFO w_p: 1
[2021-08-02 05:34:23 main_depth.py:165] INFO warmup_epochs: 0
[2021-08-02 05:34:23 main_depth.py:209] INFO Number of model parameters: 5310496
[2021-08-02 05:34:28 main_depth.py:219] INFO => loading pretrain './results/checkpoint.pth.tar'
[2021-08-02 05:34:28 main_depth.py:227] INFO => loading checkpoint './results/sdn_kitti_trainval/checkpoint.pth.tar'
[2021-08-02 05:34:29 main_depth.py:235] INFO => loaded checkpoint './results/sdn_kitti_trainval/checkpoint.pth.tar' (epoch 300)
0%| | 0/1880 [00:00<?, ?it/s]

about kitti dataset image_3

How did you use image_3 as training data? There are no official labels for image_3; only label_2, corresponding to image_2, exists. If the labels are generated by a transformation, can you explain how to do it?

Question on pre-trained models

Hi, thank you for releasing the pre-trained depth models. Could you please advise whether there is a script to reproduce CDN-DSGN in the current repo, or instructions for doing so?

Amazing work

Thanks for your contribution to this community!
When will you provide your pretrained weights?
