conta's Introduction

Causal Intervention for Weakly Supervised Semantic Segmentation

The main code for:

Causal Intervention for Weakly Supervised Semantic Segmentation. Dong Zhang, Hanwang Zhang, Jinhui Tang, Xiansheng Hua, and Qianru Sun. NeurIPS, 2020. [CONTA]

Requirements

  • PyTorch 1.2.0, torchvision 0.4.0, and more in requirements.txt
  • PASCAL VOC 2012 devkit and COCO 2014
  • 8 NVIDIA GPUs with more than 1024MB of memory

Usage

Install python dependencies

pip install -r requirements.txt

Download PASCAL VOC 2012 and COCO

To generate pseudo_mask:

For pseudo-mask generation, we follow the IRNet method without the instance-wise CAM generation step.

cd pseudo_mask && python run_sample.py

  • You can either manually edit the file or specify command-line arguments.
  • Replace the ground-truth annotations in PASCAL VOC 2012 with the generated pseudo-masks.
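
The replacement step above can be scripted. The following is a minimal sketch, not part of this repository: the function name replace_annotations, the three directory arguments, and the assumption that pseudo-masks are PNG files named like the annotations they replace are all hypothetical. It backs up each original annotation before overwriting it so that the swap can be undone.

```python
import shutil
from pathlib import Path


def replace_annotations(pseudo_dir, annotation_dir, backup_dir):
    """Copy pseudo-mask PNGs over same-named annotations, backing up originals.

    Hypothetical helper: directory layout and the .png extension are assumptions.
    Returns the sorted list of file names that were copied.
    """
    pseudo_dir = Path(pseudo_dir)
    annotation_dir = Path(annotation_dir)
    backup_dir = Path(backup_dir)
    backup_dir.mkdir(parents=True, exist_ok=True)
    annotation_dir.mkdir(parents=True, exist_ok=True)

    replaced = []
    for mask in sorted(pseudo_dir.glob("*.png")):
        target = annotation_dir / mask.name
        backup = backup_dir / mask.name
        # Save the original ground truth only once, so reruns never
        # overwrite the backup with an already-replaced pseudo-mask.
        if target.exists() and not backup.exists():
            shutil.copy2(target, backup)
        shutil.copy2(mask, target)
        replaced.append(mask.name)
    return replaced
```

For example, replace_annotations("pseudo_mask/result", "VOCdevkit/VOC2012/SegmentationClass", "VOCdevkit/VOC2012/SegmentationClass_backup") would perform the swap, where all three paths are placeholders to adapt to your own setup.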

To train segmentation model:

cd segmentation && python main.py train --config-path configs/voc12.yaml

To evaluate the performance on the validation set:

python main.py test --config-path configs/voc12.yaml \
    --model-path data/models/voc12/deeplabv2_resnet101_msc/train_aug/final_model.pth

To re-evaluate with CRF post-processing:

python main.py crf --config-path configs/voc12.yaml

Common setting:

  • Model: DeepLab v2 with a ResNet-101 backbone. The ASPP dilation rates are (6, 12, 18, 24), and the output stride is 8.
  • GPU: All GPUs visible to the process are used. Restrict the set with CUDA_VISIBLE_DEVICES, e.g. CUDA_VISIBLE_DEVICES=0,1,2,3.
  • Multi-scale loss: The loss is summed over the responses from the multi-scale inputs (1x, 0.75x, 0.5x) and their element-wise max across the scales. The unlabeled class is ignored in the loss computation.
  • Learning rate: Stochastic gradient descent (SGD) is used with a momentum of 0.9 and an initial learning rate of 2.5e-4. Polynomial learning rate decay is employed; the learning rate is multiplied by (1-iter/iter_max)**power every 10 iterations.
  • Monitoring: Moving average loss (average_loss in Caffe) can be monitored in TensorBoard.
  • Preprocessing: Input images are randomly re-scaled by factors ranging from 0.5 to 1.5, padded if needed, and randomly cropped to 321x321.
  • You can find more useful tools in /tools/xxx.
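
The polynomial decay described in the learning-rate setting can be sketched as a small standalone function. This is an illustration only, not the scheduler shipped in this codebase: the name poly_lr is hypothetical, and it assumes the decay factor is applied to the initial learning rate and that the rate stays constant between the 10-iteration update points. The values of power and iter_max are not stated above, so the ones in the usage note are placeholders.

```python
def poly_lr(base_lr, iteration, iter_max, power, step_size=10):
    """Return the learning rate at `iteration` under stepped polynomial decay.

    Assumption: the factor (1 - iter/iter_max)**power scales the initial
    rate `base_lr`, and is only refreshed every `step_size` iterations.
    """
    # Snap to the most recent update point so the rate is piecewise constant.
    snapped = (iteration // step_size) * step_size
    return base_lr * (1.0 - snapped / iter_max) ** power
```

With placeholder values such as poly_lr(2.5e-4, iteration, iter_max=20000, power=0.9), the rate starts at 2.5e-4, first drops at iteration 10, and reaches zero at iter_max.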

Training batch normalization

This codebase only supports DeepLab v2 training, which freezes the batch normalization layers, whereas the v3/v3+ protocols require training them. If you also want to train the batch normalization parameters on multiple GPUs in your own projects, install the extra library below.

pip install torch-encoding

The batch normalization layer used by the model is selected automatically in libs/models/resnet.py.

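import torch.nn as nn

try:
    # Prefer synchronized batch norm when torch-encoding is installed.
    from encoding.nn import SyncBatchNorm
    _BATCH_NORM = SyncBatchNorm
except ImportError:
    # Fall back to the standard per-device layer.
    _BATCH_NORM = nn.BatchNorm2d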

Inference Demo

To process a single image:

python tools/demo.py single \
    --config-path configs/voc12.yaml \
    --model-path model.pth \
    --image-path image.jpg

To run on a webcam:

python tools/demo.py live \
    --config-path configs/voc12.yaml \
    --model-path model.pth

Citation

If you find the code useful, please consider citing our paper using the following BibTeX entry.

@InProceedings{dong_2020_conta,
author = {Zhang, Dong and Zhang, Hanwang and Tang, Jinhui and Hua, Xiansheng and Sun, Qianru},
title = {Causal Intervention for Weakly Supervised Semantic Segmentation},
booktitle = {NeurIPS},
year = 2020
}

References

  1. L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, A. L. Yuille. DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs. IEEE TPAMI, 2018.
    Project / Code / Paper

  2. M. Everingham, L. Van Gool, C. K. I. Williams, J. Winn, A. Zisserman. The PASCAL Visual Object Classes (VOC) Challenge. IJCV, 2010.
    Project / Paper

  3. J. Ahn, S. Cho, S. Kwak. Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations. CVPR, 2019.
    Project / Paper

TO DO

  • Training code for MS-COCO
  • Code refactoring
  • Release the checkpoint

Questions

Please contact '[email protected]'
