Comments (3)
When computing the loss, you should not use preds at all. Preds involves a sigmoid so it can saturate and fail to learn. That is why your implementation has a bug where if you set the learning rate slightly too high, it will get stuck printing the same loss value (of roughly 22) over and over again. A correctly implemented cross entropy cost cannot get stuck at a specific high value.
My suggestion is not to use:
self.d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(self.D, tf.ones_like(self.D))
My suggestion is to use:
self.d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)),
where
self.D = tf.sigmoid(self.D_logits).
Your cross entropy calculation is incorrect because you compute the preds first, and the sigmoid can get numerically rounded to 0 or 1. To compensate for that, you have to do tricks like adding and subtracting eps:
https://github.com/carpedm20/DCGAN-tensorflow/blob/master/ops.py#L35
to avoid taking the log of zero, etc.
If implemented correctly, the cross-entropy does not need any eps tricks. The tricks make your cross entropy calculation only slightly inaccurate for most inputs. The important thing is that they make the gradient of the cross entropy very inaccurate for extreme inputs.
tf.nn.sigmoid_cross_entropy_with_logits has the correct implementation.
There's a longer explanation of cross-entropy and why it can't get stuck if implemented properly in section 6.2.1 of the textbook here: http://www.deeplearningbook.org/contents/mlp.html
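The numerical point can be checked without TensorFlow. Below is a small NumPy sketch (my own, not code from the repo): `naive_bce` computes cross entropy from probabilities the way the buggy code does, and `bce_with_logits` uses the stable formula that the TensorFlow docs give for `tf.nn.sigmoid_cross_entropy_with_logits`, namely `max(x, 0) - x*z + log(1 + exp(-|x|))`:

```python
import numpy as np

def naive_bce(p, z):
    # Cross entropy computed directly from probabilities,
    # as in the buggy implementation (no eps tricks here).
    return -(z * np.log(p) + (1.0 - z) * np.log(1.0 - p))

def bce_with_logits(x, z):
    # Numerically stable form documented for
    # tf.nn.sigmoid_cross_entropy_with_logits:
    # max(x, 0) - x*z + log(1 + exp(-|x|))
    return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

x = 100.0                      # an extreme logit
p = 1.0 / (1.0 + np.exp(-x))   # sigmoid saturates to exactly 1.0 in float64

with np.errstate(divide="ignore"):
    print(naive_bce(p, 0.0))       # inf: log(1 - 1.0) blows up
print(bce_with_logits(x, 0.0))     # 100.0: finite loss, usable gradient
```

Once the sigmoid rounds to exactly 1.0, the probability-based loss (and its gradient) is infinite or undefined, while the logits-based form stays finite for any input.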
from dcgan-tensorflow.
I think the variable names are wrong (the function name should be just binary_cross_entropy, and logits and targets should be renamed), but the calculation seems right to me.
In the TensorFlow documentation, tf.nn.sigmoid_cross_entropy_with_logits is calculated as:
let x = logits, z = targets.
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
and in my understanding, it wants to calculate:
p_target * -log(p_pred) + (1 - p_target) * -log(1 - p_pred).
self.D is tf.nn.sigmoid(h4), which makes it p_pred, and binary_cross_entropy_with_logits is calculated as:
tf.reduce_mean(-(logits * tf.log(targets) + (1. - logits) * tf.log(1. - targets)))
= tf.reduce_mean(-(tf.ones_like(self.D) * tf.log(self.D) + (1. - tf.ones_like(self.D)) * tf.log(1. - self.D)))
and I think this is right, because tf.ones_like(self.D) is the target probability and self.D is the predicted probability, so this computes the cross entropy between them.
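This equivalence can be checked numerically. Here is a small NumPy sketch (my own, not from the repo) that compares the probability-based formulation with the stable logits formula from the TensorFlow docs, for a moderate logit where the sigmoid is far from saturation:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_from_probs(p, z):
    # the repo's formulation: cross entropy from probabilities
    return -(z * np.log(p) + (1.0 - z) * np.log(1.0 - p))

def bce_with_logits(x, z):
    # stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits
    return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

x = 2.0            # a moderate logit, far from saturation
p = sigmoid(x)
for z in (0.0, 1.0):
    # the two forms agree to floating-point precision here
    print(z, bce_from_probs(p, z), bce_with_logits(x, z))
```

So the value is the same in the well-behaved regime; the two formulations only diverge once the sigmoid saturates.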
Following your suggestion, if we use:
self.d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(self.D, tf.ones_like(self.D)),
self.D (a probability) will be passed into sigmoid(x), and the calculation will look like:
z * -log(sigmoid(self.D)) + (1 - z) * -log(1 - sigmoid(self.D))
= p_target * -log(sigmoid(p_pred)) + (1 - p_target) * -log(1 - sigmoid(p_pred))
and this seems to be the wrong way to calculate cross entropy.
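The double-sigmoid effect is easy to demonstrate. In this NumPy sketch (hypothetical names, assuming the stable formula from the TF docs), passing a probability where a logit is expected gives a noticeably different loss value:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_with_logits(x, z):
    # stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits
    return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

x = 2.0
p = sigmoid(x)                    # predicted probability, ~0.881

right = bce_with_logits(x, 1.0)   # -log(sigmoid(2.0)),   ~0.127
wrong = bce_with_logits(p, 1.0)   # -log(sigmoid(0.881)), ~0.347: double sigmoid
print(right, wrong)
```

Feeding a probability into the op applies the sigmoid twice, squashing everything toward 0.5 and inflating the loss, which is exactly why self.D_logits (not self.D) must be the first argument.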
I'm not an expert like you, so please correct me if there is anything wrong in my explanation.
from dcgan-tensorflow.
@goodfeli Thanks for the nice explanation! I think I need to fix the other models too.
from dcgan-tensorflow.