
Comments (3)

goodfeli commented on May 25, 2024

When computing the loss, you should not use preds at all. preds involves a sigmoid, so it can saturate and the model can fail to learn. That is why your implementation has a bug: if you set the learning rate slightly too high, it gets stuck printing the same loss value (roughly 22) over and over again. A correctly implemented cross-entropy cost cannot get stuck at a specific high value.

My suggestion is not to use:

self.d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(self.D, tf.ones_like(self.D))

My suggestion is to use:

self.d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)),

where

self.D = tf.sigmoid(self.D_logits).

Your cross entropy calculation is incorrect because you compute the preds first, and the sigmoid can get numerically rounded to 0 or 1. To compensate for that, you have to do tricks like adding and subtracting eps:
https://github.com/carpedm20/DCGAN-tensorflow/blob/master/ops.py#L35
to avoid taking the log of zero, etc.
If implemented correctly, the cross-entropy does not need any eps tricks. The tricks make your cross entropy calculation only slightly inaccurate for most inputs. The important thing is that they make the gradient of the cross entropy very inaccurate for extreme inputs.
tf.nn.sigmoid_cross_entropy_with_logits has the correct implementation.
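
The difference can be seen in a small numeric sketch. This is plain Python rather than TensorFlow, the function names are hypothetical, and eps = 1e-12 is an assumed value chosen for illustration, not necessarily the one used in ops.py. The stable form is the formula given in the TensorFlow documentation for tf.nn.sigmoid_cross_entropy_with_logits: max(x, 0) - x * z + log(1 + exp(-abs(x))).

```python
import math


def sigmoid(x):
    # Logistic function written so that exp() never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def stable_bce_with_logits(x, z):
    # Cross entropy computed directly from the logit x and target z.
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def stable_grad(x, z):
    # Exact derivative of the loss above with respect to x.
    return sigmoid(x) - z


def eps_bce_from_probs(p, z, eps=1e-12):
    # Cross entropy computed from the probability p, with the eps trick
    # (eps is an assumed illustrative value).
    return -(z * math.log(p + eps) + (1.0 - z) * math.log(1.0 - p + eps))


def eps_grad(x, z, eps=1e-12):
    # Derivative of eps_bce_from_probs(sigmoid(x), z) with respect to x.
    p = sigmoid(x)
    dp_dx = p * (1.0 - p)
    return -(z / (p + eps) - (1.0 - z) / (1.0 - p + eps)) * dp_dx


# A confidently wrong prediction: very negative logit, target of 1.
x, z = -40.0, 1.0

loss_stable = stable_bce_with_logits(x, z)
loss_eps = eps_bce_from_probs(sigmoid(x), z)
grad_stable = stable_grad(x, z)
grad_eps = eps_grad(x, z)

print("stable loss:", loss_stable, "gradient:", grad_stable)
print("eps loss:   ", loss_eps, "gradient:", grad_eps)
```

For this logit of -40 with target 1, the logits-based loss is about 40 with a gradient near -1, while the eps version reports a loss under 28 with a gradient on the order of 1e-6, so learning effectively stalls.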

There's a longer explanation of cross-entropy and why it can't get stuck if implemented properly in section 6.2.1 of the textbook here: http://www.deeplearningbook.org/contents/mlp.html

from dcgan-tensorflow.

carpedm20 commented on May 25, 2024

I think the variable names are wrong (the function should be named just binary_cross_entropy, and logits and targets should be swapped), but the calculation seems right to me.

In the TensorFlow documentation, tf.nn.sigmoid_cross_entropy_with_logits is calculated as:

let x = logits, z = targets.
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))

and in my understanding, it wants to calculate:

p_target * -log(p_pred) + (1 - p_target) * -log(1 - p_pred).

self.D is tf.nn.sigmoid(h4), which makes it p_pred, and binary_cross_entropy_with_logits is calculated as:

tf.reduce_mean(-(logits * tf.log(targets) + (1. - logits) * tf.log(1. - targets)))
= tf.reduce_mean(-(tf.ones_like(self.D) * tf.log(self.D) + (1. - tf.ones_like(self.D)) * tf.log(1. - self.D)))

and I think this is right because tf.ones_like(self.D) is a target probability and self.D is a predicted probability, so this computes the cross entropy between them.

Following your suggestion, if we use:

self.d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(self.D, tf.ones_like(self.D)),

self.D (a probability) will be passed into sigmoid(x), and the calculation will look like:

z * -log(sigmoid(self.D)) + (1 - z) * -log(1 - sigmoid(self.D))
= p_target * -log(sigmoid(p_pred)) + (1 - p_target) * -log(1 - sigmoid(p_pred))

and this seems to be the wrong way to calculate cross entropy.
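
A small numeric check of the formulas above, as a plain-Python sketch (the function names and the example logit of 2.0 are hypothetical, chosen only for illustration). It compares the probability-based cross entropy, the logits-based form applied to the pre-sigmoid value, and the logits-based form applied to a value that has already been through a sigmoid.

```python
import math


def sigmoid(x):
    # Logistic function written so that exp() never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def bce_from_probs(p, z):
    # p_target * -log(p_pred) + (1 - p_target) * -log(1 - p_pred)
    return -(z * math.log(p) + (1.0 - z) * math.log(1.0 - p))


def bce_with_logits(x, z):
    # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)),
    # written in the overflow-safe form from the TensorFlow documentation.
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


logit = 2.0            # stands in for the pre-sigmoid output (D_logits)
target = 1.0           # stands in for tf.ones_like(self.D)
prob = sigmoid(logit)  # stands in for self.D

loss_from_prob = bce_from_probs(prob, target)
loss_from_logit = bce_with_logits(logit, target)
loss_double_sigmoid = bce_with_logits(prob, target)  # probability fed in as a logit

print(loss_from_prob, loss_from_logit, loss_double_sigmoid)
```

The first two values agree, while the third differs because the sigmoid is applied twice. So the logits-based function gives the intended cross entropy only when it receives the pre-sigmoid value.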

I'm not an expert like you, so please let me know if there is anything wrong in my explanation.


carpedm20 commented on May 25, 2024

@goodfeli Thanks for the nice explanation! I think I need to fix my other models too.

