SCELoss-PyTorch

Official Repo: https://github.com/YisenWang/symmetric_cross_entropy_for_noisy_labels
Reproduces the results of the ICCV 2019 paper "Symmetric Cross Entropy for Robust Learning with Noisy Labels".

Update

In the official TensorFlow repo, the model applies L2 weight decay of 0.01 to model.fc1, which gives better results. The code here has been updated accordingly and should now show performance similar to that reported in the paper.
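
To make the regularisation concrete, here is a minimal sketch in plain Python of an L2 penalty applied to a single layer only. It assumes the Keras-style convention (coefficient times the sum of squared weights, with no factor of 1/2); the `params` dictionary, the `fc2` layer and the function names are hypothetical and are not taken from this repository.

```python
def l2_penalty(weights, coeff):
    """Keras-style L2 penalty: coeff * sum(w ** 2) over one layer's weights."""
    return coeff * sum(w * w for w in weights)


def regularised_loss(data_loss, params, coeff=0.01, decayed=("fc1",)):
    """Add the L2 penalty only for the layers named in `decayed`."""
    penalty = sum(l2_penalty(params[name], coeff) for name in decayed)
    return data_loss + penalty


# Hypothetical flattened weights for two layers; only fc1 is penalised.
params = {"fc1": [0.5, -1.0, 2.0], "fc2": [3.0, 3.0]}
total = regularised_loss(data_loss=1.0, params=params)
```

Here the penalty is 0.01 * (0.25 + 1.0 + 4.0), and the weights of `fc2` do not contribute.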

How To Run

Arguments
  • --loss: 'SCE' or 'CE'
  • --nr: noise rate, from 0.0 to 1.0
  • --dataset_type: 'cifar10' or 'cifar100'
  • --alpha: alpha for SCE
  • --beta: beta for SCE
  • --seed: random seed
  • --version: For experiment notes
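
The options above could be declared with `argparse` roughly as follows. This is a sketch, not the repository's train.py: all defaults are illustrative assumptions, and `--lr` and `--l2_reg` are included only because they appear in the example commands.

```python
import argparse


def build_parser():
    # Defaults below are illustrative assumptions, not the repository's values.
    parser = argparse.ArgumentParser(description="SCE / CE training options (sketch)")
    parser.add_argument("--loss", choices=["SCE", "CE"], default="SCE")
    parser.add_argument("--nr", type=float, default=0.0, help="noise rate in [0.0, 1.0]")
    parser.add_argument("--dataset_type", choices=["cifar10", "cifar100"], default="cifar10")
    parser.add_argument("--alpha", type=float, default=1.0, help="alpha for SCE")
    parser.add_argument("--beta", type=float, default=1.0, help="beta for SCE")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--version", type=str, default="", help="experiment notes")
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--l2_reg", type=float, default=0.0, help="L2 weight decay")
    return parser


def parse(argv):
    args = build_parser().parse_args(argv)
    # Reject noise rates outside the documented range.
    if not 0.0 <= args.nr <= 1.0:
        raise ValueError("--nr must be between 0.0 and 1.0")
    return args


args = parse(["--loss", "SCE", "--dataset_type", "cifar10", "--l2_reg", "1e-2",
              "--seed", "123", "--alpha", "0.1", "--beta", "1.0",
              "--version", "SCE0.4_CIFAR10", "--nr", "0.4"])
```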

Example: symmetric noise rate of 0.4 with SCE loss

# CIFAR10
$ python3 -u train.py  --loss         SCE               \
                       --dataset_type cifar10           \
                       --l2_reg       1e-2              \
                       --seed         123               \
                       --alpha        0.1               \
                       --beta         1.0               \
                       --version      SCE0.4_CIFAR10    \
                       --nr           0.4

# CIFAR100
$ python3 -u train.py  --lr           0.01              \
                       --loss         SCE               \
                       --dataset_type cifar100          \
                       --l2_reg       1e-2              \
                       --seed         123               \
                       --alpha        6.0               \
                       --beta         1.0               \
                       --version      SCE0.4_CIFAR100   \
                       --nr           0.4
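
For intuition about what `--nr 0.4` means, the sketch below corrupts 40% of a label list with symmetric noise. It assumes that each selected label is replaced by a different class chosen uniformly at random; some definitions also allow the true class to be drawn again. The function name and the toy labels are hypothetical and do not come from dataset.py.

```python
import random


def add_symmetric_noise(labels, noise_rate, num_classes, seed=0):
    """Return a copy of `labels` with a `noise_rate` fraction flipped.

    Assumption: each selected label is replaced by a different class
    chosen uniformly at random.
    """
    rng = random.Random(seed)
    noisy = list(labels)
    n_flip = round(noise_rate * len(noisy))
    # rng.sample picks distinct positions, so exactly n_flip labels change.
    for i in rng.sample(range(len(noisy)), n_flip):
        others = [c for c in range(num_classes) if c != noisy[i]]
        noisy[i] = rng.choice(others)
    return noisy


clean = [i % 10 for i in range(1000)]
noisy = add_symmetric_noise(clean, noise_rate=0.4, num_classes=10, seed=123)
changed = sum(c != n for c, n in zip(clean, noisy))
```

Because the flipped positions are distinct and every flip picks a different class, `changed` equals the requested fraction of the list.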

Results on CIFAR10

Accuracy (%) at the best epoch

Loss \ Noise rate   0.0    0.2    0.4    0.6    0.8
CE                  92.68  84.70  72.77  54.14  31.23
SCE                 92.05  89.96  84.65  73.77  36.28

Results on CIFAR100

Loss \ Noise rate   0.0    0.2    0.4    0.6    0.8
CE                  73.84  61.70  42.88  20.47   4.88
SCE                 73.57  62.31  46.50  24.00  12.51

sceloss-reproduce's People

Contributors

hanxunh

sceloss-reproduce's Issues

Error when predictions have extra dimensions other than classes

torch.nn.CrossEntropyLoss takes predictions of shape (minibatch, classes, d1, d2, ...). The following line raises an error when the d1, d2, ... dimensions are present.

rce = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1))

The predictions are of shape (minibatch, classes, d1, d2, ...) but the label_one_hot variable is of shape (minibatch, d1, d2, ..., classes). I think the label_one_hot tensor should be transposed to make the class dimension the second dimension.
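
The transposition suggested above amounts to a permutation of axes. The sketch below computes that axis order and the resulting shape in plain Python; the helper names and the example shape are hypothetical.

```python
def class_last_to_second(ndim):
    """Axis order that moves the last axis to position 1, keeping the rest in order."""
    if ndim < 2:
        raise ValueError("need at least (minibatch, classes)")
    return (0, ndim - 1) + tuple(range(1, ndim - 1))


def permute_shape(shape, order):
    """Shape that results from reordering axes according to `order`."""
    return tuple(shape[axis] for axis in order)


# One-hot labels of shape (minibatch, d1, d2, classes), e.g. a segmentation map.
one_hot_shape = (8, 32, 32, 10)
order = class_last_to_second(len(one_hot_shape))
new_shape = permute_shape(one_hot_shape, order)
```

With PyTorch tensors, the same order could be applied as `label_one_hot.permute(*order)`, after which the one-hot tensor lines up with predictions of shape (minibatch, classes, d1, d2).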

I have a question about downloading the dataset.

Dear Hanxun Huang,

Thank you very much for your PyTorch implementation. I have a question about it.
I have a problem downloading the dataset. The error is shown below:

Traceback (most recent call last):
  File "train.py", line 233, in <module>
    train()
  File "train.py", line 187, in train
    dataset = DatasetGenerator(batchSize=args.batch_size,
  File "/home/user/SCELoss-Reproduce/dataset.py", line 168, in __init__
    self.data_loaders = self.loadData()
  File "/home/user/SCELoss-Reproduce/dataset.py", line 217, in loadData
    train_dataset = cifar10Nosiy(root=self.dataPath,
  File "/home/user/SCELoss-Reproduce/dataset.py", line 69, in __init__
    super(cifar10Nosiy, self).__init__(root, transform=transform, target_transform=target_transform)
  File "/home/user/.pyenv/versions/anaconda3-2020.02/envs/pytorch/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 61, in __init__
    raise RuntimeError('Dataset not found or corrupted.' +
RuntimeError: Dataset not found or corrupted. You can use download=True to download it

I confirmed that download=True is set, so I tried a different approach: I downloaded the dataset manually from the website, set download=False in dataset.py, and ran it again. However, I got the same error.

Sorry to ask this of you when you are busy but I appreciate your help.
Thanks so much.
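
One observation, offered tentatively: the parent-constructor call shown in the traceback passes only root, transform and target_transform. If the parent class is torchvision's CIFAR10, its download argument would then keep its default of False regardless of what is set elsewhere. For a manual download, the sketch below checks whether the files torchvision's CIFAR10 class expects are present under the root directory. It checks presence only, not the checksums torchvision also verifies, and the function name is hypothetical.

```python
import os
import tempfile

# File layout torchvision's CIFAR10 class looks for under `root`.
CIFAR10_FOLDER = "cifar-10-batches-py"
CIFAR10_FILES = ["data_batch_%d" % i for i in range(1, 6)] + ["test_batch", "batches.meta"]


def missing_cifar10_files(root):
    """List the expected CIFAR-10 files that are absent under `root`."""
    folder = os.path.join(root, CIFAR10_FOLDER)
    return [name for name in CIFAR10_FILES
            if not os.path.isfile(os.path.join(folder, name))]


# Demonstration on a temporary directory that holds only one batch file.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, CIFAR10_FOLDER))
    open(os.path.join(root, CIFAR10_FOLDER, "data_batch_1"), "w").close()
    missing = missing_cifar10_files(root)
```

An empty list means the extracted archive sits where the loader looks for it; a non-empty list suggests the archive was extracted to a different directory or not extracted at all.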

How to regulate A?

Hello, I saw in the article that the parameter A is set and that it depends on alpha. In the experiments, A has a fixed value (e.g., A = -6 or -4). I would like to know where A is set in the code.
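
A possible reading, stated as an assumption: if the implementation avoids log(0) by clamping the zeros of the one-hot label to a small positive value before taking the log, then A is not a separate option but equals the log of that value. The sketch below uses a hypothetical `eps` for the clamp and shows the closed form the reverse cross entropy takes for a one-hot label, where only the probability mass on the wrong classes contributes.

```python
import math


def a_from_clamp(eps):
    """Value that log(0) effectively takes when zeros are clamped to `eps`."""
    return math.log(eps)


def rce_one_hot(p_true, a):
    """Reverse cross entropy for a one-hot label: -A * (1 - p_true)."""
    return -a * (1.0 - p_true)


a = a_from_clamp(1e-4)          # eps = 1e-4 is an illustrative choice
loss = rce_one_hot(p_true=0.7, a=a)

# Conversely, a target such as A = -4 corresponds to eps = exp(-4).
eps_for_minus_4 = math.exp(-4)
```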

About backward

Hi, I would like to know which file contains the backpropagation for SCELoss. Thanks!

How to regulate the alpha and the beta of SCEloss?

Hi, I want to apply SCELoss to other visual tasks, such as semantic segmentation. How should I go about tuning the alpha and beta of SCELoss? Also, I can't find the paper's parameter A in the code.

why not flip the label for rce?

I would like to ask: why not just swap the labels and predictions to calculate rce?

If ce = self.cross_entropy(pred, labels), won't rce just be rce = self.cross_entropy(labels, pred)?
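
Conceptually the two terms are the same formula with the roles swapped, but two practical points get in the way of calling torch.nn.CrossEntropyLoss with swapped arguments: that loss expects raw logits as its first argument and applies a log-softmax to them, and a one-hot label contains zeros that would end up inside the log. The sketch below works on a single sample in plain Python, with a hypothetical `eps` standing in for log(0); it is an illustration of the idea rather than this repository's implementation.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target, pred, eps):
    """H(target, pred) = -sum_k target_k * log(max(pred_k, eps))."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))


def sce(logits, label, alpha, beta, eps=1e-4):
    """alpha * CE + beta * RCE for one sample (log(eps) stands in for log 0)."""
    p = softmax(logits)
    q = [1.0 if k == label else 0.0 for k in range(len(logits))]
    ce = cross_entropy(q, p, eps=1e-7)   # label as target, prediction inside the log
    rce = cross_entropy(p, q, eps=eps)   # roles swapped: the label is inside the log
    return alpha * ce + beta * rce, ce, rce


total, ce, rce = sce([2.0, 0.5, -1.0], label=0, alpha=0.1, beta=1.0)
```

The same `cross_entropy` helper serves both terms; only the argument order and the clamp differ.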
