
Tidying Deep Saliency Prediction Architectures

This repository contains the PyTorch implementation of SimpleNet and MDNSal, appearing in the proceedings of the IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS) 2020.

Cite

Please cite with the following Bibtex code:

@inproceedings{Navya-IROS-2020,
               AUTHOR = {Navyasri Reddy and Samyak Jain and Pradeep Yarlagadda and Vineet Gandhi},
               TITLE = {Tidying Deep Saliency Prediction Architectures},
               BOOKTITLE = {IROS},
               YEAR = {2020}
}

Abstract

Learning computational models for visual attention (saliency estimation) is an effort to inch machines/robots closer to human visual cognitive abilities. Data-driven efforts have dominated the landscape since the introduction of deep neural network architectures. In deep learning research, the choices in architecture design are often empirical and frequently lead to more complex models than necessary. The complexity, in turn, hinders the application requirements. In this paper, we identify four key components of saliency models, i.e., input features, multi-level integration, readout architecture, and loss functions. We review the existing state-of-the-art models on these four components and propose novel and simpler alternatives. As a result, we propose two novel end-to-end architectures called SimpleNet and MDNSal, which are neater, minimal, more interpretable, and achieve state-of-the-art performance on public saliency benchmarks. SimpleNet is an optimized encoder-decoder architecture and brings notable performance gains on the SALICON dataset (the largest saliency benchmark). MDNSal is a parametric model that directly predicts the parameters of a GMM distribution and aims to bring more interpretability to the prediction maps. The proposed saliency models run at 25 fps, making them ideal for real-time applications.
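To make the parametric output concrete: a saliency map can be rendered from a predicted Gaussian mixture by evaluating the mixture density on the image grid. The sketch below only illustrates the idea; the component count, parameter names, and shapes are assumptions, not MDNSal's exact output format.

import torch

def render_gmm_saliency(means, variances, weights, height, width):
    # means:     (K, 2) component centres in normalized [0, 1] (x, y) coordinates
    # variances: (K, 2) per-axis variances (diagonal covariance assumed)
    # weights:   (K,)   mixture weights summing to 1
    ys = torch.linspace(0, 1, height)
    xs = torch.linspace(0, 1, width)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    sal = torch.zeros(height, width)
    for k in range(means.shape[0]):
        dx = (grid_x - means[k, 0]) ** 2 / (2 * variances[k, 0])
        dy = (grid_y - means[k, 1]) ** 2 / (2 * variances[k, 1])
        sal += weights[k] * torch.exp(-(dx + dy))
    return sal / (sal.max() + 1e-8)  # normalize to [0, 1] for visualization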

Architecture

SimpleNet Architecture

MDNSal Architecture

Testing

Clone this repository and download the pretrained SimpleNet weights (for multiple encoders, trained on the SALICON dataset) from this link. The trained weights for MobileNetV2 can be found here.

Then just run the code using

$ python3 test.py --val_img_dir path/to/test/images --results_dir path/to/results --model_val_path path/to/saved/models

This will generate saliency maps for all images in the images directory and dump them into the results directory.
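A quick way to inspect the dumped maps is to overlay one on its source image. A minimal sketch using Pillow and NumPy; the file names below are placeholders for an image from --val_img_dir and its map from --results_dir.

import numpy as np
from PIL import Image

# Hypothetical file names; substitute your own image and its generated map.
img = Image.open("images/example.jpg").convert("RGB")
sal = Image.open("results/example.png").convert("L").resize(img.size)

img_arr = np.asarray(img, dtype=np.float32)
sal_arr = np.asarray(sal, dtype=np.float32)
sal_arr = (sal_arr - sal_arr.min()) / (sal_arr.max() - sal_arr.min() + 1e-8)

# Blend the saliency map into the red channel for a quick visual check.
heat = np.zeros_like(img_arr)
heat[..., 0] = 255.0 * sal_arr
overlay = (0.5 * img_arr + 0.5 * heat).astype(np.uint8)
Image.fromarray(overlay).save("overlay.png")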

Training

To train the model from scratch, download the pretrained PNASNet weights from here and place them in the PNAS/ folder. Then run the following command to train:

$ python3 train.py --dataset_dir path/to/dataset 

The dataset directory structure should be:

└── Dataset
    ├── fixations
    │   ├── train
    │   └── val
    ├── images
    │   ├── train
    │   └── val
    └── maps
        ├── train
        └── val
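If you are assembling the dataset yourself, the layout above can be created with a short script before copying in the SALICON files (the root folder name is just an example; pass its path via --dataset_dir):

from pathlib import Path

root = Path("Dataset")  # example root; point --dataset_dir at this folder
for sub in ("fixations", "images", "maps"):
    for split in ("train", "val"):
        (root / sub / split).mkdir(parents=True, exist_ok=True)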

To train the model for MIT1003 or CAT2000, first train it on the SALICON dataset and then fine-tune the weights on MIT1003 or CAT2000.

Experiments

  • Multiple Encoders

For training the model, we provide encoders based on PNASNet, DenseNet-161, VGG-16, ResNet-50 and MobileNetV2. Run:

$ python3 train.py --enc_model <model> --train_enc <boolean value> 
<model> : {"pnas", "densenet", "resnet", "vgg", "mobilenet"}

Set --train_enc to 1 to fine-tune the encoder and 0 otherwise.

Similarly for testing the model,

$ python3 test.py --enc_model <model> --model_val_path path/to/pretrained/model --save_results <binary> --validate <binary> 

Set the save_results flag to 1 to save the generated saliency maps, and set the validate flag to 1 to evaluate the model quantitatively.
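Internally, the --enc_model flag selects which backbone supplies the encoder features. The sketch below shows how such a dispatch could look with torchvision backbones and the older pretrained= API; it illustrates the idea and is not the repository's exact code.

import torch.nn as nn
import torchvision.models as models

def build_encoder(name, pretrained=True, train_enc=False):
    # Return a convolutional backbone for the given --enc_model name (sketch only).
    if name == "densenet":
        enc = models.densenet161(pretrained=pretrained).features
    elif name == "resnet":
        # Drop the average-pool and fc layers to keep spatial feature maps.
        enc = nn.Sequential(*list(models.resnet50(pretrained=pretrained).children())[:-2])
    elif name == "vgg":
        enc = models.vgg16(pretrained=pretrained).features
    elif name == "mobilenet":
        enc = models.mobilenet_v2(pretrained=pretrained).features
    else:
        # "pnas" is handled separately via the downloaded PNASNet weights.
        raise ValueError(f"unsupported encoder: {name}")
    for p in enc.parameters():
        p.requires_grad = bool(train_enc)  # freeze the encoder unless fine-tuning is requested
    return enc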

  • Multiple Loss functions

To train the model with a combination of loss functions, run the following command:

$ python3 train.py --<loss_function> True --<loss_function>_coeff <coefficient of the loss>
<loss_function> : {"kldiv", "cc", "nss", "sim"}

By default, the loss function is KLDiv with coefficient 1.0.
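The selected losses are combined as a weighted sum between the predicted and ground-truth maps. The sketch below follows the standard definitions of KLDiv and CC for saliency maps; the function names, coefficients, and sign conventions are illustrative (NSS and SIM terms would be added the same way), not the repository's exact implementation.

import torch

def kldiv(pred, gt, eps=1e-8):
    # Treat each map as a probability distribution over pixels.
    pred = pred / (pred.sum(dim=(-2, -1), keepdim=True) + eps)
    gt = gt / (gt.sum(dim=(-2, -1), keepdim=True) + eps)
    return (gt * torch.log(gt / (pred + eps) + eps)).sum(dim=(-2, -1)).mean()

def cc(pred, gt, eps=1e-8):
    # Pearson correlation between standardized maps.
    pred = (pred - pred.mean(dim=(-2, -1), keepdim=True)) / (pred.std(dim=(-2, -1), keepdim=True) + eps)
    gt = (gt - gt.mean(dim=(-2, -1), keepdim=True)) / (gt.std(dim=(-2, -1), keepdim=True) + eps)
    return (pred * gt).mean(dim=(-2, -1)).mean()

def combined_loss(pred, gt, kldiv_coeff=1.0, cc_coeff=-1.0):
    # KLDiv is minimized while CC is maximized, hence the negative coefficient.
    return kldiv_coeff * kldiv(pred, gt) + cc_coeff * cc(pred, gt)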

Quantitative Results

  • SALICON Test

The results of our models on the SALICON test dataset can be viewed here under the names SimpleNet and MDNSal. Comparison with other state-of-the-art saliency detection models:

  • MIT Test

Comparison with other state-of-the-art saliency detection models on the MIT300 test set:

Qualitative Results

Contact

For any questions, please contact [email protected], [email protected] or [email protected], or use the public issues section of this repository.

License

This code is distributed under the MIT License.


Issues

Questions about the test results on the SALICON test set

In the paper, the authors state "The model trained on SALICON was finetuned using MIT1003 and CAT2000." I am wondering whether the model is fine-tuned on MIT1003 and CAT2000 when generating the results on the SALICON 2017 test set.
In short:
on SALICON: model trained on SALICON -> test on SALICON.
on MIT300: model trained on SALICON+MIT1003 -> test on MIT300.
or like this:
on SALICON: model trained on SALICON+MIT1003+CAT2000 -> test on SALICON.
...

Loss weights for reproducing results

Hi,

I would like to know what coefficient values you used when combining multiple losses, as they are not stated in the article.
Are they the default values in the code (i.e., 1 for KLDiv, -1 for CC, etc.)?

Thanks!

Detailed architecture of ReadOut and code for MDNSal?

Thanks for the nice work.

Could you please release the code for MDNSal?

"The readout architecture con- sists of a convolutional layer to reduce the number of channels followed by a ReLU. " But the detailed number of channels and kernel size is not clear. Thanks.

NSS Metric/Loss During Test

Hi,

Thank you very much for the great work and for publishing the code. I have a question about the ground truth for the NSS metric.

I have evaluated your SALICON pre-trained models with the given commands. I can achieve your reported results except for the NSS metric. I get around 1.69 on the validation set, while it is stated as 1.93 in the paper. I noticed that in test.py at line 100, you use "gt" for calculating the NSS loss. I suppose "gt" is the continuous saliency map and "fixations" is the binary fixation map. After changing "gt" to "fixations" in this line, I could get an NSS score of 1.927.

I would like to confirm this approach with you.

Best wishes,
Bahar
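For reference, NSS is conventionally computed against the binary fixation map, which is consistent with the change described above. A minimal sketch of the standard definition:

import torch

def nss(pred, fixations, eps=1e-8):
    # pred:      predicted saliency map (H, W)
    # fixations: binary fixation map (H, W), 1 at fixated pixel locations
    pred = (pred - pred.mean()) / (pred.std() + eps)  # zero mean, unit variance
    return pred[fixations > 0].mean()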
