Comments (10)
Can you please share a minimal running example?
Sure, here it is:
```python
import numpy as np
import tensorflow as tf

# generator(), discriminator(), train_generator, shape (HR image shape) and
# input_shape (LR input shape) are defined elsewhere.

def get_gan_network(discriminator, input_shape, generator):
    # Freeze the discriminator inside the combined model so that only the
    # generator is updated by the adversarial loss.
    discriminator.trainable = False
    gan_input = tf.keras.layers.Input(shape=input_shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = tf.keras.models.Model(inputs=gan_input, outputs=[x, gan_output])
    return gan

class ContentLoss(object):
    def __init__(self, image_shape):
        self.loss = tf.keras.losses.MeanSquaredError()
        self.image_shape = image_shape
        vgg19 = tf.keras.applications.vgg19.VGG19(
            include_top=False, weights='imagenet', input_shape=self.image_shape)
        vgg19.trainable = False
        for layer in vgg19.layers:
            layer.trainable = False
        # layers[20] is block5_conv4, the 'VGG54' feature layer from the SRGAN paper.
        self.model = tf.keras.models.Model(inputs=vgg19.input, outputs=vgg19.layers[20].output)
        self.model.trainable = False

    def vgg_loss(self, y_true, y_pred):
        # Computes the VGG (content) loss: MSE between VGG19 feature maps of the
        # HR and SR images. preprocess_input expects pixel values in [0, 255].
        y_true = tf.keras.applications.vgg19.preprocess_input(y_true)
        y_pred = tf.keras.applications.vgg19.preprocess_input(y_pred)
        y_true_fts = self.model(y_true) / 12.75  # feature rescaling as in the SRGAN paper
        y_pred_fts = self.model(y_pred) / 12.75
        return self.loss(y_true_fts, y_pred_fts)

gen = generator()
dis = discriminator()

content_loss = ContentLoss(shape).vgg_loss
lossD = 'binary_crossentropy'
lossG = [content_loss, 'binary_crossentropy']
optimizerG = tf.keras.optimizers.Adam(
    tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[100000], values=[1e-4, 1e-5]))
optimizerD = tf.keras.optimizers.Adam(1e-4)

dis.compile(loss=lossD, optimizer=optimizerD, metrics=['accuracy'])
gan = get_gan_network(dis, input_shape, gen)
gan.compile(loss=lossG, loss_weights=[1., 1e-3], optimizer=optimizerG)

for j in range(len(train_generator)):
    lr, hr = train_generator[j]  # LR/HR image batch
    batch_size = lr.shape[0]

    # Adversarial ground truths (one-sided label smoothing on the real labels)
    real = np.ones(batch_size)
    fake = np.zeros(batch_size)
    real_smooth = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
    fake_smooth = np.random.random_sample(batch_size) * 0.2  # defined but currently unused

    generated_images_sr = gen.predict(lr)

    # --------------------- Train the Discriminator ---------------------
    d_loss_real = dis.train_on_batch(hr, real_smooth)
    d_loss_fake = dis.train_on_batch(generated_images_sr, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # --------------------- Train the Generator ---------------------
    gan_loss = gan.train_on_batch(lr, [hr, real])
```
Your example trains the generator from scratch, but it should first be pre-trained (alone) with mean_absolute_error. Is this only missing from your example, or did you skip that step entirely?
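For reference, such a pre-training step could look like the following sketch (the optimizer settings, epoch count and filename are assumptions, not values from this thread):

```python
# Hypothetical generator pre-training with a pixel-wise loss before GAN fine-tuning.
gen.compile(loss='mean_absolute_error', optimizer=tf.keras.optimizers.Adam(1e-4))
gen.fit(train_generator, epochs=100)         # epoch count is an assumption
gen.save_weights('pretrained_generator.h5')  # hypothetical filename, restored before GAN training
```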
In your implementation you trained it first with "mse" loss.
And yes, I am loading the pre-trained weights; I just didn't mention it here.
Btw, pre-training the generator is not required according to the paper, right?
I think the discriminator is still getting too confident (even after one-sided label smoothing), so I may have to smooth the labels for both sides, but that is not advisable (according to a lot of sources).
How come you are not facing this issue, given that you are not using any label smoothing or adding any noise to the discriminator input? Of all the different GANs I have trained, none of them worked with hard labels and noise-free discriminator training.
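For context, the two tricks mentioned here, two-sided label smoothing and discriminator input noise, would look roughly like this inside the training loop from the earlier example (the noise scale is an assumption):

```python
# Two-sided label smoothing: train on the fake_smooth labels defined earlier
# instead of hard zeros, and optionally add instance noise to both real and
# generated discriminator inputs. The noise scale is an assumption and is
# typically annealed towards zero during training.
noise_std = 0.05
hr_noisy = hr + np.random.normal(0.0, noise_std, size=hr.shape)
sr_noisy = generated_images_sr + np.random.normal(0.0, noise_std, size=generated_images_sr.shape)
d_loss_real = dis.train_on_batch(hr_noisy, real_smooth)
d_loss_fake = dis.train_on_batch(sr_noisy, fake_smooth)
```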
> In your implementation you trained it first with "mse" loss.

Yes, SRResNet is trained with mse; EDSR and WDSR (which can also be fine-tuned with an SRGAN discriminator) are trained with mae.
> Btw, pre-training the generator is not required according to the paper, right?
No, the paper does generator pre-training: "We employed the trained MSE-based SRResNet network as initialization for the generator when training the actual GAN to avoid undesired local optima."
> I think the discriminator is still getting too confident ...
According to section 4.4 of this tutorial, it is fine when the discriminator overpowers the generator. During SRGAN training, I'm also seeing a low discriminator loss, although it's not zero.
> How come you are not facing this issue, given that you are not using any label smoothing or adding any noise to the discriminator input? ...
First of all, the generator is primarily trained with a content loss whose weight is 1000x higher than the weight of the adversarial generator loss, so it is not really comparable to a "classic" GAN training where only an adversarial generator loss is used. Second, the implementation in this repository is just a plain re-implementation of the paper; I didn't try any tweaks.
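For concreteness, the combined objective implied by `gan.compile(loss=lossG, loss_weights=[1., 1e-3], ...)` in the example above can be restated schematically, using names from that example:

```python
# Schematic per-batch generator objective: the content term dominates the
# adversarial term by a factor of 1000 (sr = gen(lr), real = np.ones(batch_size)).
bce = tf.keras.losses.BinaryCrossentropy()
total_g_loss = 1.0 * content_loss(hr, sr) + 1e-3 * bce(real, dis(sr))
```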
I'm closing this ticket for now as it seems the issues you're facing are not related to the code in this repository. Please re-open if you think this is not the case.
Can you share the loss values at the end of your training? Content loss and adversarial loss?
- content loss: 0.066 (low variance)
- adversarial generator loss: 3.301 (low variance)
- adversarial discriminator loss: moving between 0.1 and 0.5
My content loss fluctuates between 0.060 and 0.065, my generator adversarial loss fluctuates between 0.03 and 0.09, and my discriminator adversarial loss goes as low as 1e-10.
Is this normal, or is something wrong?
I know it's not related to your project; I just need your opinion.
Thanks
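In case it helps, one common mitigation for a collapsing discriminator loss (not taken from this repository) is to throttle discriminator updates; a sketch against the training loop above, with an assumed threshold:

```python
# Hypothetical: skip discriminator updates while it is already near-perfect,
# reusing d_loss from the previous batch ([loss, accuracy] because of the
# accuracy metric). The threshold value is an assumption.
if d_loss[0] > 1e-3:
    d_loss_real = dis.train_on_batch(hr, real_smooth)
    d_loss_fake = dis.train_on_batch(generated_images_sr, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
```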
Related Issues (20)
- Scale x2
- Progressive training
- The training time per iteration changes when the batch_size is changed
- A TensorFlow error in WDSR
- ValueError: The channel dimension of the inputs should be defined. Found `None`.
- Pre-trained model for super-resolution by x2/x3
- PSNR doesn't change on custom dataset
- How to retain the audio of the videos?
- Super-resolution comparison.ipynb
- Training Time
- x1 model?
- AssertionError
- ValueError: axes don't match array when loading WDSR pre-trained weights
- How can I run it using distributed training?
- How should I customize data for custom dataset training?
- How should I create cache files?
- How to plot train vs. val loss over epochs
- Weight normalization for tf versions that are incompatible with tfa
- missing evaluation file
- weird output on pre-trained weights