RGB-D Salient Object Detection Using Conditional GAN

This code is a PyTorch implementation of RGB-D salient object detection using a conditional GAN (cGAN). A two-stream network generates a pixel-wise saliency map, and a PatchGAN discriminator learns to determine whether a given saliency map is real or generated (fake).
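
A PatchGAN discriminator does not output one real/fake score for the whole map. It outputs a grid of scores, and each score judges one overlapping patch of the input. The sketch below shows how large each patch is and how many patches a 256x256 input yields. It assumes the standard 70x70 PatchGAN layer stack from pix2pix; this repository's discriminator may use different layers. The --ndf option only changes channel widths, so it does not affect this geometry.

```python
# Sketch: receptive field and output grid of a pix2pix-style PatchGAN.
# ASSUMPTION: the layer stack below is the standard 70x70 pix2pix
# configuration, not necessarily the discriminator used in this repository.

# Each layer is (kernel_size, stride, padding).
PATCHGAN_70 = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def receptive_field(layers):
    """Receptive field (in input pixels) of one output unit of a conv stack."""
    rf, jump = 1, 1  # jump: distance in input pixels between adjacent units
    for kernel, stride, _ in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf


def output_size(input_size, layers):
    """Spatial size of the discriminator's score grid for a square input."""
    size = input_size
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size


if __name__ == "__main__":
    print("receptive field:", receptive_field(PATCHGAN_70))  # 70
    print("output grid:", output_size(256, PATCHGAN_70))      # 30
```

With this stack, a 256x256 input produces a 30x30 grid of scores, and each score covers a 70x70 patch.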

Requirements

  • Python 3
  • PyTorch 1.4
  • Torchvision 0.5.0
  • Pillow

Experimental environment

Getting started

  1. Clone this repository

     git clone https://github.com/wj1224/rgb-d_salient_object_detection.git
    
  2. Prepare datasets

    We use the NLPR and NJUDS2000 RGB-D saliency detection datasets to train the networks. Additionally, the DUT-OMRON, HKU-IS, and MSRA10K RGB saliency datasets are used with synthetic depth maps generated using pix2pix.

  3. Training

    cd rgb-d_salient_object_detection
    python main.py \
        --mode train \
        --input_dir path/to/trainset \
        --output_dir path/to/logs \
        --max_epochs 100 \
        --cuda \
        --[args]
    

    See step 5 for more arguments.

  4. Testing

    python main.py \
        --mode test \
        --input_dir path/to/testset \
        --output_dir path/to/output_saliency_maps \
        --checkpoint path/to/saved_logs \
        --n_epochs 100 \
        --cuda
    
  5. More details on arguments

    main.py accepts the following options via --[args]:

    --mode ["train", "test"] : Train or test mode
    --input_dir [path/to/imgs] : Path to the folder containing input images
    --output_dir [path/to/output] : Path to the folder for logs (training) or output images (testing)
    --checkpoint [path/to/logs] : Path to the folder used to resume training or for testing
    --n_epochs [100] : Epoch number of the saved checkpoint to load
    --max_epochs [100] : Number of training epochs
    --batch_size [16] : Mini-batch size
    --cuda : Use the GPU
    --threds : Number of threads for data loading
    --ngf [64] : Number of filters in the first convolution layer of the generator
    --ndf [16] : Number of filters in the first convolution layer of the discriminator
    --lr [0.0002] : Learning rate of the Adam optimizer
    --beta1 [0.9] : beta1 (first-moment decay rate) of the Adam optimizer
    --lambda_g [10.0] : Weight of the CrossEntropyLoss term in the generator loss
    --lambda_gp [1.0] : Weight of the gradient penalty term in the discriminator loss
    
  6. Pretrained model

    To test with the pretrained model, download it and put it in path/to/logs. The model was trained on the datasets described in step 2. You can test it with the following command:

    python main.py \
        --mode test \
        --input_dir path/to/testset \
        --output_dir path/to/output_saliency_maps \
        --checkpoint path/to/pretrained_model \
        --pretrained \
        --cuda
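
The options listed in step 5 (plus the --pretrained flag from step 6) could be declared with argparse roughly as follows. This is a hypothetical reconstruction, not the repository's main.py. Types are inferred from the documented defaults. The default for --threds is not documented, so the value 4 used here is an assumption.

```python
import argparse


def build_parser():
    # HYPOTHETICAL reconstruction of the documented options; the real
    # main.py may differ in types, defaults, and additional flags.
    p = argparse.ArgumentParser(description="RGB-D salient object detection using cGAN")
    p.add_argument("--mode", choices=["train", "test"], required=True)
    p.add_argument("--input_dir", required=True, help="folder containing input images")
    p.add_argument("--output_dir", required=True, help="logs (train) or saliency maps (test)")
    p.add_argument("--checkpoint", default=None, help="folder to resume from or test with")
    p.add_argument("--n_epochs", type=int, default=100, help="epoch of the checkpoint to load")
    p.add_argument("--max_epochs", type=int, default=100, help="number of training epochs")
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--cuda", action="store_true", help="use the GPU")
    p.add_argument("--threds", type=int, default=4,  # ASSUMPTION: default not documented
                   help="number of data-loading threads")
    p.add_argument("--ngf", type=int, default=64, help="generator filters in first conv")
    p.add_argument("--ndf", type=int, default=16, help="discriminator filters in first conv")
    p.add_argument("--lr", type=float, default=0.0002, help="Adam learning rate")
    p.add_argument("--beta1", type=float, default=0.9, help="Adam beta1")
    p.add_argument("--lambda_g", type=float, default=10.0, help="weight of CE term (G)")
    p.add_argument("--lambda_gp", type=float, default=1.0, help="weight of gradient penalty (D)")
    p.add_argument("--pretrained", action="store_true", help="load the pretrained model")
    return p


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--mode", "train", "--input_dir", "data/train", "--output_dir", "logs", "--cuda"]
    )
    print(args.batch_size, args.lr, args.cuda)  # 16 0.0002 True
```

Using `choices` makes argparse reject a typo such as `--mode eval` before any data is loaded.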
    

Architecture
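
Step 5 states that --lambda_g weights a cross-entropy term in the generator loss and --lambda_gp weights a gradient penalty in the discriminator loss. The generator side can be sketched as an adversarial term plus the lambda_g-weighted pixel-wise cross-entropy. The sketch assumes a pix2pix-style binary cross-entropy adversarial term; the repository's actual formulation may differ, since a Wasserstein-style critic is common when a gradient penalty is used. The gradient penalty needs automatic differentiation and is omitted. Plain Python keeps the arithmetic visible; in PyTorch, `torch.nn.BCELoss` would compute the same mean binary cross-entropy.

```python
import math

EPS = 1e-7  # keeps log() finite when a probability is exactly 0 or 1


def bce(probs, targets):
    """Mean binary cross-entropy between probabilities and 0/1 targets."""
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, EPS), 1.0 - EPS)
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total / len(probs)


def generator_loss(d_fake, pred_map, gt_map, lambda_g=10.0):
    """ASSUMED form: BCE adversarial term + lambda_g * pixel-wise cross-entropy.

    d_fake: discriminator patch scores (probabilities) for the generated map.
    pred_map, gt_map: flattened predicted and ground-truth saliency maps.
    """
    adversarial = bce(d_fake, [1.0] * len(d_fake))  # G wants patches judged real
    pixel = bce(pred_map, gt_map)
    return adversarial + lambda_g * pixel


print(round(generator_loss([0.5, 0.5], [0.5, 0.5], [1.0, 0.0]), 4))  # 7.6246
```

With the default lambda_g of 10.0, the pixel-wise term dominates unless the discriminator is very confident, so the adversarial term acts mainly as a refinement signal.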

Results

  • NLPR testset

  • NJUDS2000 testset

  • F-measure scores of the model trained with both real and synthetic depth maps, compared with training without any depth maps and training without the synthetic depth maps

    Dataset     Only RGB   RGB + real depth map   RGB + real and synthetic depth map
    NLPR        0.7705     0.7780                 0.8103
    NJUDS2000   0.8014     0.8405                 0.8567
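
The README does not state how the F-measure above was computed. A minimal sketch of the usual definition is F = (1 + beta^2) * P * R / (beta^2 * P + R), computed on a saliency map binarized at a threshold. Both beta^2 = 0.3 (the value commonly used in salient object detection benchmarks) and the fixed 0.5 threshold are assumptions. Many benchmarks instead use an adaptive threshold or report the maximum F over all thresholds, so the table's numbers should not be assumed to follow this exact procedure.

```python
BETA_SQ = 0.3  # ASSUMPTION: beta^2 = 0.3, common in saliency benchmarks


def f_measure(pred, gt, threshold=0.5, beta_sq=BETA_SQ):
    """F-measure of a saliency map binarized at `threshold` against a 0/1 ground truth.

    pred: flattened saliency values in [0, 1]; gt: flattened 0/1 labels.
    """
    tp = fp = fn = 0
    for p, g in zip(pred, gt):
        predicted = p >= threshold
        if predicted and g:
            tp += 1
        elif predicted:
            fp += 1
        elif g:
            fn += 1
    if tp == 0:
        return 0.0  # no true positives: precision and recall are 0 or undefined
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


print(round(f_measure([0.9, 0.8, 0.2, 0.1], [1, 0, 1, 0]), 4))  # 0.5
```

A beta^2 below 1 weights precision more heavily than recall, which favors maps that avoid highlighting background.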

More details

Please see this.

Contributors

wj1224
