Gumbel Optimised Loss for Long-tailed Instance Segmentation

This is the official implementation of Gumbel Optimised Loss for Long-tailed Instance Segmentation, a paper accepted at ECCV 2022.

Introduction

Major advances have recently been made in object detection and segmentation. However, state-of-the-art methods still fail to detect rare categories, resulting in a significant performance gap between rare and frequent categories. In this paper, we identify that the Sigmoid and Softmax functions used in deep detectors are a major reason for this low performance and are suboptimal for long-tailed detection and segmentation. To address this, we develop a Gumbel Optimized Loss (GOL) for long-tailed detection and segmentation. It aligns with the Gumbel distribution of rare classes in imbalanced datasets, considering that most classes in long-tailed detection have a low expected probability. On the LVIS dataset, the proposed GOL significantly outperforms the best state-of-the-art method by 1.1% AP and, compared to Mask R-CNN, boosts overall segmentation by 9.0% and detection by 8.0%, particularly improving detection of rare classes by 20.3%.
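To make the contrast with Sigmoid concrete, the plain-Python sketch below evaluates both activations on a few logits. The Gumbel activation here is the standard Gumbel CDF exp(-exp(-x)), which is what the simplified loss in this README computes; the function names are illustrative and this is not the repository code. One intuition consistent with the paragraph above: Gumbel maps an uninformative logit (x = 0) to exp(-1) ≈ 0.368 instead of 0.5, and its lower tail decays double-exponentially, so low probabilities are the default.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic activation: symmetric around x = 0, so sigmoid(0) = 0.5."""
    return 1.0 / (1.0 + math.exp(-x))


def gumbel(x: float) -> float:
    """Gumbel activation: the standard Gumbel CDF exp(-exp(-x)).

    It is asymmetric: it approaches 0 double-exponentially for negative x
    but approaches 1 only exponentially for positive x.
    """
    return math.exp(-math.exp(-x))


if __name__ == "__main__":
    for x in (-4.0, -2.0, 0.0, 2.0, 4.0):
        print(f"x={x:+.1f}  sigmoid={sigmoid(x):.4f}  gumbel={gumbel(x):.4f}")
```

For large positive logits the two activations nearly coincide; the difference is concentrated at zero and negative logits, which is where most classes of a long-tailed dataset sit.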

[Figure: Performance of Gumbel activation using (M)ask R-CNN, (R)esNet, ResNe(X)t, (C)ascade Mask R-CNN and (H)ybrid Task Cascade.]

Gumbel Cross Entropy (simplified)

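import torch
import torch.nn.functional as F


def gumbel_cross_entropy(pred, label, reduction='mean'):
    """Calculate the Gumbel cross-entropy loss.

    Args:
        pred (torch.Tensor): The prediction logits.
        label (torch.Tensor): One-hot encoded labels with the same shape as pred.
        reduction (str): 'none', 'mean' or 'sum', passed to
            F.binary_cross_entropy.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Clamp the logits so the double exponential below stays numerically stable.
    pred = torch.clamp(pred, min=-4, max=10)
    # Gumbel activation: the Gumbel CDF exp(-exp(-pred)) = 1 / exp(exp(-pred)).
    pestim = torch.exp(-torch.exp(-pred))
    loss = F.binary_cross_entropy(
        pestim, label.float(), reduction=reduction)
    # Clip the loss to [0, 20] to avoid exploding values.
    loss = torch.clamp(loss, min=0, max=20)

    return loss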
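The two clamps in the simplified loss bound what the Gumbel activation and the binary cross-entropy can produce. The plain-Python sketch below (an illustration, not repository code; the constant and function names are hypothetical) evaluates the element-wise loss at the logit bounds -4 and 10. For a positive label the loss has the closed form -log(exp(-exp(-x))) = exp(-x).

```python
import math

# Logit bounds used by torch.clamp(pred, min=-4, max=10) in the snippet above.
PRED_MIN, PRED_MAX = -4.0, 10.0
LOSS_MAX = 20.0  # upper bound used by torch.clamp(loss, min=0, max=20)


def gumbel_prob(x: float) -> float:
    """Gumbel activation exp(-exp(-x)), equivalent to 1 / exp(exp(-x))."""
    return math.exp(-math.exp(-x))


def gumbel_bce(x: float, y: int) -> float:
    """Element-wise binary cross-entropy on the Gumbel probability.

    For y = 1 this reduces to exp(-x); for y = 0 it is -log(1 - p).
    """
    p = gumbel_prob(x)
    return -math.log(p) if y == 1 else -math.log(1.0 - p)


if __name__ == "__main__":
    for x in (PRED_MIN, 0.0, PRED_MAX):
        pos, neg = gumbel_bce(x, 1), gumbel_bce(x, 0)
        print(f"x={x:+5.1f}  p={gumbel_prob(x):.3e}  "
              f"loss(y=1)={pos:8.4f} (clipped {min(pos, LOSS_MAX):.4f})  "
              f"loss(y=0)={neg:8.4f}")
```

At the lower bound a missed positive costs e^4 ≈ 54.6 before clipping, above the cap of 20, while a confident false positive at the upper bound costs about 10. Note that in the simplified loss the final clamp is applied after reduction, so with reduction='mean' it caps the averaged value rather than each element.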

Tested with

  • python==3.8.12
  • torch==1.7.1
  • torchvision==0.8.2
  • mmdet==2.21.0
  • lvis
  • Tested with CUDA 10.2 on RHEL 8
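A quick way to compare your environment with the versions above is a small standard-library script (Python 3.8+, using importlib.metadata). This is a sketch: it assumes the packages are registered under the distribution names torch, torchvision, mmdet and lvis, and it ignores local version suffixes such as "+cu102". The helper names are hypothetical.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions listed under "Tested with"; None means any version (lvis is unpinned).
TESTED = {"torch": "1.7.1", "torchvision": "0.8.2", "mmdet": "2.21.0", "lvis": None}
TESTED_PYTHON = "3.8.12"


def installed_versions(packages):
    """Return {package: installed version, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(tested, installed):
    """List human-readable differences between tested and installed versions."""
    problems = []
    for name, wanted in tested.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed")
            continue
        have = have.split("+")[0]  # drop local build tags such as "+cu102"
        if wanted is not None and have != wanted:
            problems.append(f"{name}: installed {have}, tested with {wanted}")
    return problems


if __name__ == "__main__":
    py = ".".join(map(str, sys.version_info[:3]))
    if py != TESTED_PYTHON:
        print(f"python: running {py}, tested with {TESTED_PYTHON}")
    for line in mismatches(TESTED, installed_versions(TESTED)) or ["all packages match"]:
        print(line)
```

A mismatch is not necessarily an error; it only flags a combination other than the one listed above.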

Getting Started

  1. Create a virtual environment
conda create --name mmdet pytorch=1.7.1 -y
conda activate mmdet
  2. Install dependency packages
conda install torchvision -y
conda install pandas scipy -y
conda install opencv -y
  3. Install MMDetection
pip install openmim
mim install mmdet==2.21.0
  4. Clone this repo
git clone https://github.com/kostas1515/GOL.git
cd GOL
  5. Create a data directory, download the COCO 2017 datasets from https://cocodataset.org/#download (2017 Train images [118K/18GB], 2017 Val images [5K/1GB], 2017 Train/Val annotations [241MB]) and extract the zip files:
mkdir data
cd data
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip

# download and unzip the LVIS annotations
wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip
wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip

  6. Modify mmdetection/configs/_base_/datasets/lvis_v1_instance.py so that the data_root variable points to the data directory created above, e.g., data_root = '<user_path>'.
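The steps above ask you to extract the downloaded archives but do not show the command. A minimal standard-library sketch that extracts every .zip file in the data directory in place is below; the function name is hypothetical. It does not rearrange files, and this README does not state the exact sub-paths the dataset config expects (MMDetection configs usually build annotation and image paths by appending relative paths to data_root), so check those entries in the config after extracting.

```python
import sys
import zipfile
from pathlib import Path


def extract_all(data_dir):
    """Extract every *.zip archive directly inside data_dir into data_dir.

    Returns the sorted top-level names (files or folders) that were extracted.
    """
    data_dir = Path(data_dir)
    extracted = set()
    for archive in sorted(data_dir.glob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(data_dir)
            # Record the first path component of each member, e.g. "train2017".
            extracted.update(Path(name).parts[0] for name in zf.namelist())
    return sorted(extracted)


if __name__ == "__main__":
    # Default to the "data" directory created in the steps above.
    target = sys.argv[1] if len(sys.argv) > 1 else "data"
    for name in extract_all(target):
        print(name)
```

Run it from the repository root, for example python extract_data.py data (the script name is illustrative).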

Training

To train on multiple GPUs, launch training with tools/dist_train.sh:
./tools/dist_train.sh ./configs/<experiment>/<variant.py> <#GPUs>

For example, to train GOL on 4 GPUs:

./tools/dist_train.sh ./configs/gol/droploss_normed_mask_r50_rfs_4x4_2x_gumbel.py 4
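The <#GPUs> argument changes the total batch size, and MMDetection's documentation recommends scaling the learning rate linearly with the total batch size when the number of GPUs or images per GPU differs from the config's. A minimal sketch of that rule follows. It assumes, as the config name suggests but this README does not state, that "4x4" means 4 GPUs with 4 images per GPU, and the base learning rate of 0.02 is only a placeholder: read the real value from the optimizer setting of the config you train.

```python
def scaled_lr(base_lr, base_gpus, base_imgs_per_gpu, gpus, imgs_per_gpu=None):
    """Scale a learning rate linearly with the total batch size.

    Linear scaling rule: lr_new = base_lr * (new total batch / base total batch).
    """
    if imgs_per_gpu is None:
        imgs_per_gpu = base_imgs_per_gpu  # keep the per-GPU batch unchanged
    base_batch = base_gpus * base_imgs_per_gpu
    new_batch = gpus * imgs_per_gpu
    return base_lr * new_batch / base_batch


if __name__ == "__main__":
    # Hypothetical base setting: 4 GPUs x 4 images per GPU with lr = 0.02.
    for n in (1, 2, 4, 8):
        print(f"{n} GPU(s): lr = {scaled_lr(0.02, 4, 4, n):.4f}")
```

For example, halving the GPU count from 4 to 2 while keeping 4 images per GPU halves the learning rate.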

Testing

To evaluate a trained GOL model on 4 GPUs (bbox and segm metrics):

./tools/dist_test.sh ./experiments/droploss_normed_mask_rcnn_r50_rfs_4x4_2x_gumbel/droploss_normed_mask_r50_rfs_4x4_2x_gumbel.py ./experiments/droploss_normed_mask_r50_rfs_4x4_2x_gumbel/latest.pth 4 --eval bbox segm
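The test command takes the config path, the checkpoint path and the GPU count, followed by the metrics to evaluate. The sketch below builds that command from a single work directory so the config and checkpoint paths cannot drift apart. It assumes, as MMDetection does by default, that the work directory is named after the config and holds the copied config and latest.pth; the function and its parameters are hypothetical.

```python
import shlex
from pathlib import Path


def dist_test_command(work_dir, gpus, metrics=("bbox", "segm"),
                      config_name=None, checkpoint="latest.pth"):
    """Build a tools/dist_test.sh command line for a finished training run."""
    work_dir = Path(work_dir)
    # Assumption: the config copy is named after the work directory.
    config = work_dir / (config_name or f"{work_dir.name}.py")
    ckpt = work_dir / checkpoint
    args = ["./tools/dist_test.sh", str(config), str(ckpt), str(gpus),
            "--eval", *metrics]
    return shlex.join(args)  # shell-quote each argument where needed


if __name__ == "__main__":
    print(dist_test_command(
        "./experiments/droploss_normed_mask_r50_rfs_4x4_2x_gumbel", 4))
```

Pass config_name explicitly if your config copy is named differently from its directory.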

Reproduce

To reproduce the paper's results with Sigmoid, Softmax and Gumbel activations, run:
./tools/dist_train.sh ./configs/activations/r50_4x4_1x.py <#GPUs>
./tools/dist_train.sh ./configs/activations/r50_4x4_1x_softmax.py <#GPUs>
./tools/dist_train.sh ./configs/activations/gumbel/gumbel_r50_4x4_1x.py <#GPUs>

This should produce a table similar to the following:

Method    AP    APr   APc   APf   APb
Sigmoid   16.4  0.8   12.7  27.3  17.2
Softmax   15.2  0.0   10.6  26.9  16.1
Gumbel    19.0  4.9   16.8  27.6  19.1

AP: mask AP; APr, APc, APf: mask AP on rare, common and frequent categories; APb: box AP.
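The gains of Gumbel over Sigmoid can be read off the table directly. The short sketch below only does arithmetic on the numbers shown above (it adds no results) and assumes the usual LVIS reading of the columns, with APb as box AP. It shows the largest gains on rare and common categories.

```python
# Values copied from the activation table above.
RESULTS = {
    "Sigmoid": {"AP": 16.4, "APr": 0.8, "APc": 12.7, "APf": 27.3, "APb": 17.2},
    "Softmax": {"AP": 15.2, "APr": 0.0, "APc": 10.6, "APf": 26.9, "APb": 16.1},
    "Gumbel": {"AP": 19.0, "APr": 4.9, "APc": 16.8, "APf": 27.6, "APb": 19.1},
}


def gains(results, method, baseline):
    """Per-metric difference method - baseline, rounded to one decimal."""
    return {k: round(results[method][k] - results[baseline][k], 1)
            for k in results[baseline]}


if __name__ == "__main__":
    for baseline in ("Sigmoid", "Softmax"):
        print(f"Gumbel vs {baseline}: {gains(RESULTS, 'Gumbel', baseline)}")
```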

Pretrained Models

Method         AP    APr   APc   APf   APb   Model    Output
GOL_r50_v0.5   29.5  22.5  31.3  30.1  28.2  weights  log | config
GOL_r50_v1     27.7  21.4  27.7  30.4  27.5  weights  log | config
GOL_r101_v1    29.0  22.8  29.0  31.7  29.2  weights  log | config

Citation

 @inproceedings{alexandridis2022long,
   title={Long-tailed Instance Segmentation using Gumbel Optimized Loss},
   author={Alexandridis, Konstantinos Panagiotis and Deng, Jiankang and Nguyen, Anh and Luo, Shan},
   booktitle={European Conference on Computer Vision},
   pages={353--369},
   year={2022},
   organization={Springer}
 }

Acknowledgements

This code uses the mmdet framework. It also uses EQLv2 and DropLoss. Thank you for your wonderful work!
