
bi-tempered-loss's Introduction

Bi-Tempered Logistic Loss

This is not an officially supported Google product.

An overview of the method is available in the Google AI blog post.

Also, explore the interactive visualization that demonstrates the practical properties of the Bi-Tempered logistic loss.

Bi-Tempered logistic loss is a generalized softmax cross-entropy loss function with a bounded loss value per sample and a heavy-tailed softmax probability function.

Bi-tempered loss generalizes (with a bias correction term):

  • Zhang & Sabuncu. "Generalized cross entropy loss for training deep neural networks with noisy labels." In NeurIPS 2018.

which is recovered when 0.0 <= t1 <= 1.0 and t2 = 1.0. It also includes:

  • Ding & Vishwanathan. "t-Logistic regression." In NeurIPS 2010.

for t1 = 1.0 and t2 >= 1.0.

Bi-tempered loss is equal to the softmax cross entropy loss when t1 = t2 = 1.0. For 0.0 <= t1 < 1.0 and t2 > 1.0, bi-tempered loss provides a more robust alternative to the cross entropy loss for handling label noise and outliers.
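The reduction to cross-entropy and the bounded per-sample loss can both be checked numerically. The sketch below assumes the per-example loss takes the Bregman-divergence form sum_i [y_i log_t1(y_i) - y_i log_t1(p_i) - (y_i^(2-t1) - p_i^(2-t1)) / (2-t1)], evaluated on already-normalized probabilities p, so the tempered softmax controlled by t2 is left out. The helper names `log_t` and `bi_tempered_loss_from_probs` are hypothetical and are not the repository's API.

```python
import math


def log_t(x, t):
    """Tempered logarithm (x**(1 - t) - 1) / (1 - t); the natural log when t == 1."""
    if t == 1.0:
        return math.log(x)
    return (x ** (1.0 - t) - 1.0) / (1.0 - t)


def bi_tempered_loss_from_probs(labels, probs, t1):
    """Hypothetical helper: per-example loss for a target distribution `labels`
    and predicted probabilities `probs` (both assumed to sum to 1)."""
    loss = 0.0
    for y, p in zip(labels, probs):
        # Treat 0 * log_t(.) as 0 so one-hot labels are safe even at t1 == 1.
        y_log_y = y * log_t(y, t1) if y > 0.0 else 0.0
        y_log_p = y * log_t(p, t1) if y > 0.0 else 0.0
        loss += y_log_y - y_log_p - (y ** (2.0 - t1) - p ** (2.0 - t1)) / (2.0 - t1)
    return loss


# t1 == 1 with a one-hot label: matches ordinary cross-entropy -log(p_c).
print(bi_tempered_loss_from_probs([0.0, 1.0, 0.0], [0.2, 0.7, 0.1], t1=1.0), -math.log(0.7))

# t1 < 1: a confidently wrong prediction still costs less than 1 / (1 - t1) = 2.
print(bi_tempered_loss_from_probs([1.0, 0.0], [1e-12, 1.0 - 1e-12], t1=0.5))
```

At t1 = 1 the last term sums to sum(p) - sum(y) = 0, so only the cross-entropy part remains. For t1 < 1 and a one-hot label, -log_t1(p_c) never exceeds 1/(1 - t1) and the last term is non-positive, which is where the bound comes from; ordinary cross-entropy would give about 27.6 for the second example.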

TensorFlow and JAX

A replacement for the standard logistic loss function tf.losses.softmax_cross_entropy is available here:

def bi_tempered_logistic_loss(activations,
                              labels,
                              t1,
                              t2,
                              label_smoothing=0.0,
                              num_iters=5):
  """Bi-Tempered Logistic Loss with custom gradient.
  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    labels: A tensor with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.
  Returns:
    A loss tensor.
  """

Replacements are also available for the transfer functions:

Tempered version of tf.nn.sigmoid and jax.nn.sigmoid:

def tempered_sigmoid(activations, t, num_iters=5):
  """Tempered sigmoid function.
  Args:
    activations: Activations for the positive class for binary classification.
    t: Temperature > 0.0.
    num_iters: Number of iterations to run the method.
  Returns:
    A probabilities tensor.
  """

Tempered version of tf.nn.softmax and jax.nn.softmax:

def tempered_softmax(activations, t, num_iters=5):
  """Tempered softmax function.
  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature > 0.0.
    num_iters: Number of iterations to run the method.
  Returns:
    A probabilities tensor.
  """

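For experimenting outside TensorFlow or JAX, the transfer functions above can be approximated in plain Python. This is a sketch, not the repository's implementation: it assumes the tempered softmax is exp_t(a_i - lam) with exp_t(x) = [1 + (1 - t) x]_+^(1/(1-t)) and a scalar lam chosen so the outputs sum to 1, and it finds lam by bisection instead of running a fixed number of iterations. It also assumes the tempered sigmoid is the first output of a two-class tempered softmax over [a, 0]. The names `exp_t`, `tempered_softmax_sketch` and `tempered_sigmoid_sketch` are hypothetical.

```python
import math


def exp_t(x, t):
    """Tempered exponential [1 + (1 - t) x]_+ ** (1 / (1 - t)); exp(x) when t == 1.

    Only called with x <= 0 below, so the base is >= 1 whenever t > 1.
    """
    if t == 1.0:
        return math.exp(x)
    base = 1.0 + (1.0 - t) * x
    if base <= 0.0:
        return 0.0  # finite support when t < 1
    return base ** (1.0 / (1.0 - t))


def tempered_softmax_sketch(activations, t, tol=1e-12, max_steps=200):
    """Return exp_t(a_i - lam, t) with lam found by bisection so the outputs sum to 1."""
    top = max(activations)
    shifted = [a - top for a in activations]  # largest entry becomes 0

    def total(lam):
        return sum(exp_t(x - lam, t) for x in shifted)

    lo, hi = 0.0, 1.0  # total(0) >= 1 because exp_t(0, t) == 1
    while total(hi) > 1.0:  # widen the bracket until total(hi) <= 1
        lo, hi = hi, 2.0 * hi
    for _ in range(max_steps):
        mid = 0.5 * (lo + hi)
        if total(mid) > 1.0:
            lo = mid
        else:
            hi = mid
        if hi - lo < tol:
            break
    lam = 0.5 * (lo + hi)
    return [exp_t(x - lam, t) for x in shifted]


def tempered_sigmoid_sketch(activation, t):
    """Assumed form: first output of a two-class tempered softmax over [activation, 0]."""
    return tempered_softmax_sketch([activation, 0.0], t)[0]


print(tempered_softmax_sketch([2.0, 0.0, -1.0], t=1.5))  # heavier tail than softmax
print(tempered_softmax_sketch([0.0, -10.0], t=0.5))      # t < 1: exact zeros are possible
```

Bisection is used here only because it works for every t > 0, including the finite-support case t < 1. Under this assumed form, tempered_sigmoid_sketch(0, t) is 0.5 for every t and the output does not decrease as the activation grows, so thresholding at 0.5 yields the same decisions as the ordinary sigmoid; only the probability values differ.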
Citation

When referencing Bi-Tempered loss, cite this paper:

@inproceedings{amid2019robust,
  title={Robust bi-tempered logistic loss based on bregman divergences},
  author={Amid, Ehsan and Warmuth, Manfred KK and Anil, Rohan and Koren, Tomer},
  booktitle={Advances in Neural Information Processing Systems},
  pages={15013--15022},
  year={2019}
}

Contributions

We are eager to collaborate with you too! Please send us a pull request or open a GitHub issue. Please see this doc for further details.

bi-tempered-loss's People

Contributors

eamid, rohan-anil


bi-tempered-loss's Issues

How are the labels corrupted?

Hi, I'm trying to reproduce your experiments on CIFAR-100.
For the label noise experiments, the paper says that

we introduce noise by artificially corrupting a fraction of the labels and producing a new set of labels for each noise level.

But I'm still not sure exactly how to corrupt these labels. Are they shuffled, or something else? I hope you can give me a few more instructions on this; a few lines of code would be best.

Thanks a lot.
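One common way to produce such noisy labels, offered here as an assumption rather than the procedure used in the paper, is symmetric noise: pick a fixed fraction of examples at random and replace each chosen label with a different class drawn uniformly. The function `corrupt_labels` and its `seed` argument are hypothetical. It also returns the indices it changed, which act as ground truth for which instances are noisy.

```python
import random


def corrupt_labels(labels, num_classes, noise_fraction, seed=0):
    """Symmetric label noise (assumed scheme): flip `noise_fraction` of the labels
    to a different class chosen uniformly at random.

    Returns the corrupted labels and the sorted indices that were changed.
    """
    rng = random.Random(seed)  # fixed seed makes the corrupted label set reproducible
    num_noisy = int(round(noise_fraction * len(labels)))
    noisy_ids = sorted(rng.sample(range(len(labels)), num_noisy))
    corrupted = list(labels)
    for i in noisy_ids:
        other_classes = [c for c in range(num_classes) if c != labels[i]]
        corrupted[i] = rng.choice(other_classes)
    return corrupted, noisy_ids


clean = [i % 100 for i in range(1000)]  # toy stand-in for CIFAR-100 labels
noisy, changed = corrupt_labels(clean, num_classes=100, noise_fraction=0.2)
print(len(changed), sum(a != b for a, b in zip(clean, noisy)))
```

Because every chosen label is forced to a different class, the realized noise rate equals the requested fraction exactly; drawing from all classes instead would leave some chosen labels unchanged.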

TF 2.0 Version

The current code for the loss function does not work directly in TF 2.0. Are there any plans to port it over?

ValueError: Rank mismatch: Rank of labels (received 2) should equal rank of logits minus 1 (received 2)

I'm trying to make a custom loss function based on the sparse_bi_tempered_logistic_loss( ) function of this repository.

T_1 = 0.2
T_2 = 1.2
SMOOTH_FRACTION = 0.01
N_ITER = 5

def bi_tempered_loss(y_pred,y_true):
        return sparse_bi_tempered_logistic_loss(y_pred,y_true,T_1,T_2)

my_model.compile(loss=bi_tempered_loss, optimizer=keras.optimizers.Adam(lr=1e-4), metrics=['accuracy'])

The labels are integers from 0 to 4.

However, the following error occurs when I fit my model:


ValueError: in user code:

    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    <ipython-input-44-efd725700411>:12 bi_tempered_loss  *
        return sparse_bi_tempered_logistic_loss(y_pred,y_true,T_1,T_2)
    /kaggle/working/loss.py:421 sparse_bi_tempered_logistic_loss  *
        loss_values = tf.cond(
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1180 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/cond_v2.py:85 cond_v2
        op_return_value=pred)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:986 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/nn_ops.py:4084 sparse_softmax_cross_entropy_with_logits
        (labels_static_shape.ndims, logits.get_shape().ndims))

    ValueError: Rank mismatch: Rank of labels (received 2) should equal rank of logits minus 1 (received 2).


I've tried using class-based custom loss implementation, but it gave the same error. Am I missing something?

How do I implement Tempered_softmax in C?

Dear authors,

I used your loss function as my loss function and it works well! Now I need to carry this work further in C, but the implementation of tempered_softmax does not seem to run in real time. How can I implement it in C?

Use sigmoid or tempered_sigmoid for prediction?

Dear authors,

I'm encountering an issue regarding the prediction phase. Assume we use bi_tempered_binary_logistic_loss as our loss function: should we still use sigmoid to calculate probabilities, or should we use tempered_sigmoid? Thanks for clarifying!

Accuracy results on cifar100

This paper reports 75.30% accuracy on the clean test set, but I obtain 78.16% accuracy on the same test set using ResNet-50 with an SGD + momentum optimizer trained for 350 epochs.

Training is too slow

When I set Temperature 1 = 0.2 and Temperature 2 = 4 to handle random noise, I find that training is too slow. Does anyone know why?

Why did you use Bregman divergence instead of KL divergence?

In the paper, you say that the Bregman divergence is a generalization of the relative entropy.
If we instead use the KL divergence with your exp_t and log_t, the loss has the shape shown in the following image:
(image not available)

I derived this from your equation in the middle of page 5.

This looks quite complicated to calculate, but I don't think it is impossible.

How to calculate "simple integration" in Chapter 3

(image not available)
This image is from your paper.
You said "Via simple integration, ...".
I did the integration myself, but the constant term $\frac{1}{2-t}$ is unclear to me.
How did you get this term? What constraint did you apply?
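Assuming the integrated expression in the omitted image is the convex function F_t whose derivative is the tempered logarithm log_t (this is our reading; the image itself is not available here), the integration can be written out as below. Any constant C gives a valid antiderivative, and C cancels in the Bregman divergence D_{F_t}, so it cannot change the loss. The choice C = 1/(2-t) normalizes F_t(1) = 0 and makes F_t tend to y log y - y + 1 as t -> 1, the generator of the (unnormalized) relative entropy.

```latex
\log_t(y) = \frac{y^{1-t}-1}{1-t}, \qquad \frac{d}{dy}\log_t(y) = y^{-t}

\int \log_t(y)\,dy = \frac{1}{1-t}\Big(\frac{y^{2-t}}{2-t} - y\Big) + C
                   = y\log_t(y) - \frac{y^{2-t}}{2-t} + C

C = \frac{1}{2-t} \;\Rightarrow\; F_t(y) = y\log_t(y) + \frac{1 - y^{2-t}}{2-t},
\qquad F_t(1) = 0, \quad F_t'(y) = \log_t(y)

\lim_{t \to 1} F_t(y) = y\log y - y + 1

D_{F_t}(y, \hat{y}) = F_t(y) - F_t(\hat{y}) - (y - \hat{y})\log_t(\hat{y})
                    = y\log_t(y) - y\log_t(\hat{y}) - \frac{y^{2-t} - \hat{y}^{2-t}}{2-t}
```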

Accuracy results on MNIST

Hi,
I was reading your paper and was surprised by a few results.
First, the accuracy of what you call logistic loss is much lower than what is typically obtained on MNIST. With the same model as you describe in the paper, I obtain 99.3% accuracy on the MNIST test set after 10 epochs of SGD training. How can you obtain 98.08% accuracy with 500 epochs of training on clean samples?
Second, I tried to reproduce the results with label corruption and I see the same issue. With 50% label corruption, the same model, and 10 epochs of SGD training, I obtain 98% accuracy on the test set.

How do you explain such low results in your experiments?

Why is 5 the default num_iters?

I cannot understand why the method contains a loop with 5 iterations, since this hyperparameter is not mentioned anywhere, including the original paper.
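The docstrings above describe num_iters only as the "Number of iterations to run the method". Assuming it bounds an iterative solve for the tempered-softmax normalization constant lam (the value that makes sum_i exp_t(a_i - lam) = 1), the sketch below shows one such scheme, a fixed-point iteration for t > 1, and measures how far the outputs are from summing to 1 after k iterations. Each step rescales the shifted activations by z**(1 - t), where z is the current sum, and lam = mu - log_t(1/z) is recovered at the end. The function names are hypothetical and this is not presented as the repository's code.

```python
def log_t(x, t):
    """Tempered logarithm for t != 1."""
    return (x ** (1.0 - t) - 1.0) / (1.0 - t)


def exp_t(x, t):
    """Tempered exponential for t > 1; only called with x <= 0, so the base is >= 1."""
    return (1.0 + (1.0 - t) * x) ** (1.0 / (1.0 - t))


def fixed_point_normalizer(activations, t, num_iters):
    """Return lam with sum_i exp_t(a_i - lam, t) approximately 1 (t > 1 only)."""
    if t <= 1.0:
        raise ValueError("this sketch assumes t > 1")
    mu = max(activations)
    a0 = [a - mu for a in activations]  # shifted so the largest entry is 0
    scaled = list(a0)
    for _ in range(num_iters):
        z = sum(exp_t(x, t) for x in scaled)  # z >= 1 because one term is exp_t(0) = 1
        scaled = [x * z ** (1.0 - t) for x in a0]
    z = sum(exp_t(x, t) for x in scaled)
    return -log_t(1.0 / z, t) + mu


def normalization_residual(activations, t, num_iters):
    """How far the resulting probabilities are from summing to 1."""
    lam = fixed_point_normalizer(activations, t, num_iters)
    return abs(sum(exp_t(a - lam, t) for a in activations) - 1.0)


for k in (0, 1, 2, 5, 10):
    print(k, normalization_residual([2.0, 0.0, -1.0], 1.5, k))
```

On small examples like this one the residual shrinks by a roughly constant factor per iteration, so a handful of iterations already brings the sum close to 1; that trade-off between cost and accuracy is one plausible reason for a small default such as 5.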

noisy instances

Hi,
thanks for sharing your implementation. Is it possible to identify the noisy instances (return the noisy IDs or the clean set)?

Thanks!

loss_test.py fails in test_gradient_error

Hello,
When I changed the values of t1 and t2 in loss_test.py, line 98 (the test_gradient_error function), to
t1 = 0.5
t2 = 1.0
instead of
t1 = 0.5
t2 = 1.5
the test failed.
Here is the error message:
Traceback (most recent call last):
  File "/home/tanya/code/loss_test.py", line 123, in test_gradient_error
    activations, labels, 0.5, 1.0)
  File "/home/tanya/code/bitempered_loss/loss.py", line 174, in _internal_bi_tempered_logistic_loss
    beta = 1.0 + one_minus_t1
UnboundLocalError: local variable 'one_minus_t1' referenced before assignment

It looks like one_minus_t1 is computed only when t2 > 1.0, so the code fails when t2 = 1.0 and t1 < 1.0.

Thanks.
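The failure described here is a general Python pattern: a local variable assigned only inside one branch of an `if` is unbound when that branch is skipped. The toy functions below are hypothetical and only mirror the variable names in the traceback (`one_minus_t1`, `beta`); they are not the repository's loss code. Computing the value before branching removes the error.

```python
def beta_buggy(t1, t2):
    if t2 > 1.0:
        one_minus_t1 = 1.0 - t1  # only assigned on this branch
    return 1.0 + one_minus_t1  # UnboundLocalError when t2 <= 1.0


def beta_fixed(t1, t2):
    one_minus_t1 = 1.0 - t1  # assigned unconditionally, before any branch on t2
    return 1.0 + one_minus_t1


try:
    beta_buggy(0.5, 1.0)
except UnboundLocalError as err:
    print("buggy:", err)
print("fixed:", beta_fixed(0.5, 1.0))
```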

NaN loss during training

Dear authors,

I used bi_tempered_binary_logistic_loss as my network's loss function, and I noticed that after a few hundred training steps, some loss values become NaN. I printed the output of bi_tempered_binary_logistic_loss and confirmed that the NaN values come from it. Any idea why? Do I need to adjust the parameter settings, and if so, how? Thank you for answering!
