Comments (10)

krasserm commented on September 21, 2024

Can you please share a minimal running example?

from super-resolution.

Shubham3101 commented on September 21, 2024

Sure, here it is:

    import numpy as np
    import tensorflow as tf

    # Not shown here: generator() and discriminator() build the SRGAN models,
    # `shape` is the HR image shape fed to VGG19, `input_shape` is the LR image
    # shape and `train_generator` is a keras.utils.Sequence of (lr, hr) batches.

    def get_gan_network(discriminator, input_shape, generator):
        # Freeze the discriminator inside the combined model so that
        # gan.train_on_batch() only updates the generator.
        discriminator.trainable = False
        gan_input = tf.keras.layers.Input(shape=input_shape)
        x = generator(gan_input)
        gan_output = discriminator(x)
        # Outputs: SR image (content loss) and discriminator score (adversarial loss).
        gan = tf.keras.models.Model(inputs=gan_input, outputs=[x, gan_output])
        return gan

    class ContentLoss(object):

        def __init__(self, image_shape):
            self.loss = tf.keras.losses.MeanSquaredError()
            self.image_shape = image_shape
            vgg19 = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet', input_shape=self.image_shape)
            vgg19.trainable = False
            for l in vgg19.layers:
                l.trainable = False
            # layers[20] is block5_conv4 (the "VGG54" feature map in the SRGAN paper).
            self.model = tf.keras.models.Model(inputs=vgg19.input, outputs=vgg19.layers[20].output)
            self.model.trainable = False

        # Computes the VGG (content) loss.
        def vgg_loss(self, y_true, y_pred):
            # preprocess_input expects RGB values in [0, 255].
            y_true = tf.keras.applications.vgg19.preprocess_input(y_true)
            y_pred = tf.keras.applications.vgg19.preprocess_input(y_pred)

            # Rescale the feature maps by 1/12.75, as in the SRGAN paper.
            y_true_fts = self.model(y_true) / 12.75
            y_pred_fts = self.model(y_pred) / 12.75

            return self.loss(y_true_fts, y_pred_fts)

    gen = generator()
    dis = discriminator()

    content_loss = ContentLoss(shape).vgg_loss
    lossD = 'binary_crossentropy'
    lossG = [content_loss, "binary_crossentropy"]
    optimizerG = tf.keras.optimizers.Adam(
        tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[100000], values=[1e-4, 1e-5]))
    optimizerD = tf.keras.optimizers.Adam(0.0001)

    dis.compile(loss=lossD, optimizer=optimizerD, metrics=["accuracy"])

    gan = get_gan_network(dis, input_shape, gen)
    # Content loss weighted 1, adversarial loss weighted 1e-3.
    gan.compile(loss=lossG, loss_weights=[1., 1e-3], optimizer=optimizerG)

    for j in range(len(train_generator)):
        train_data = train_generator[j]
        lr, hr = train_data[0], train_data[1]
        batch_size = lr.shape[0]

        # Adversarial ground truths (one-sided label smoothing: only real labels are softened)
        real = np.ones(batch_size)
        fake = np.zeros(batch_size)
        real_smooth = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
        fake_smooth = np.random.random_sample(batch_size) * 0.2  # unused with one-sided smoothing

        generated_images_sr = gen.predict(lr)

        # --------------------- Train the Discriminator ---------------------
        d_loss_real = dis.train_on_batch(hr, real_smooth)
        d_loss_fake = dis.train_on_batch(generated_images_sr, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --------------------- Train the Generator ---------------------
        gan_loss = gan.train_on_batch(lr, [hr, real])
    

krasserm commented on September 21, 2024

Your example trains the generator from scratch, but it should first be pre-trained (alone) with mean_absolute_error. Is this only missing from your example, or did you skip it entirely?

Shubham3101 commented on September 21, 2024

In your implementation you trained it first with "mse" loss.

And yes, I am loading the pre-trained weights; I just didn't mention it here.

Shubham3101 commented on September 21, 2024

Btw, pre-training the generator is not required according to the paper, right?

Shubham3101 commented on September 21, 2024

I think the discriminator is still getting too confident (even after one-sided label smoothing), so I may have to smooth the labels on both sides, but that is not advisable (according to a lot of sources).

How come you are not facing this issue, given that you are not using any label smoothing or adding any noise to the discriminator input? None of the different models (GANs) I have trained worked with hard labels and noise-free discriminator training.
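
For readers comparing the two schemes: one-sided label smoothing, as in the posted training loop, only softens the real labels (uniformly in (0.8, 1.0]) and keeps the fake labels at 0, while two-sided smoothing also lifts the fake labels (uniformly in [0, 0.2)). A minimal standard-library sketch of both, mirroring the `np.random.random_sample(...) * 0.2` expressions above; the function name `smoothed_labels` and its parameters are hypothetical:

```python
import random


def smoothed_labels(batch_size, two_sided=False, amount=0.2, rng=random):
    """Return (real_labels, fake_labels) for one discriminator batch.

    One-sided smoothing (the default) draws real labels uniformly from
    (1 - amount, 1.0] and keeps fake labels at exactly 0.0.
    Two-sided smoothing additionally draws fake labels from [0.0, amount).
    """
    real = [1.0 - rng.random() * amount for _ in range(batch_size)]
    if two_sided:
        fake = [rng.random() * amount for _ in range(batch_size)]
    else:
        fake = [0.0] * batch_size
    return real, fake


# Seeded generator so the example is reproducible.
rng = random.Random(0)
real, fake = smoothed_labels(4, rng=rng)
print(real, fake)
```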

krasserm commented on September 21, 2024

> In your implementation you trained it first with "mse" loss.

Yes, SRResNet is trained with MSE; EDSR and WDSR (which can also be fine-tuned with the SRGAN discriminator) are trained with MAE.

> Btw, pre-training the generator is not required according to the paper, right?

No, the paper does generator pre-training: "We employed the trained MSE-based SRResNet network as initialization for the generator when training the actual GAN to avoid undesired local optima."

> I think the discriminator is still getting too confident ...

According to Section 4.4 of this tutorial, it is fine when the discriminator overpowers the generator. During SRGAN training, I'm also seeing a low discriminator loss, although it's not zero.
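
As a reference point for "low but not zero": with binary cross-entropy, a discriminator that outputs 0.5 for every image scores ln 2 ≈ 0.693, so a loss below that means it separates real from generated images better than chance. A small sketch, assuming a single sigmoid output and the averaged real/fake loss from the training loop posted earlier (the helper names are hypothetical):

```python
import math


def bce(label, prob):
    """Binary cross-entropy for one sample with sigmoid output `prob`."""
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))


def discriminator_loss(p_real, p_fake):
    """Average of the real-image and generated-image terms,
    as in `0.5 * (d_loss_real + d_loss_fake)`."""
    return 0.5 * (bce(1.0, p_real) + bce(0.0, p_fake))


# A discriminator that cannot tell real from generated outputs 0.5 everywhere.
print(round(discriminator_loss(0.5, 0.5), 4))  # 0.6931 (ln 2)
# A confident discriminator: 0.9 on real images, 0.1 on generated ones.
print(round(discriminator_loss(0.9, 0.1), 4))  # 0.1054
```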

> How come you are not facing this issue, given that you are not using any label smoothing or adding any noise to the discriminator input? None of the different models (GANs) I have trained worked with hard labels and noise-free discriminator training.

First of all, the generator is primarily trained with a content loss whose weight is 1000x higher than the weight of the adversarial generator loss, so this is not really comparable to "classic" GAN training, where only an adversarial generator loss is used. Second, the implementation in this repository is just a plain re-implementation of the paper. I didn't try any tweaks.
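
For reference, this is the perceptual loss from the SRGAN paper, where $l^{SR}_X$ is the (VGG-based) content loss and $l^{SR}_{Gen}$ the adversarial generator loss; the $10^{-3}$ factor corresponds to `loss_weights=[1., 1e-3]` in the Keras code posted earlier:

$$
l^{SR} = l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen},
\qquad
l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\big(G_{\theta_G}(I^{LR})\big)
$$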

I'm closing this ticket for now as it seems the issues you're facing are not related to the code in this repository. Please re-open if you think this is not the case.

Shubham3101 commented on September 21, 2024

Can you share the loss values at the end of your training? Content loss and adversarial loss?

krasserm commented on September 21, 2024
  • content loss: 0.066 (low variance)
  • adversarial generator loss: 3.301 (low variance)
  • adversarial discriminator loss: moving between 0.1 and 0.5
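
To see how much each term contributes, the reported values can be plugged into the weighted sum. A minimal sketch, assuming 3.301 is the unweighted adversarial loss and using the 1e-3 weight from the SRGAN paper (variable names are illustrative):

```python
# Values reported above at the end of SRGAN training.
content_loss = 0.066
adversarial_generator_loss = 3.301  # assumed to be unweighted

# Weight of the adversarial term relative to the content term (SRGAN paper).
ADV_WEIGHT = 1e-3

weighted_adv = ADV_WEIGHT * adversarial_generator_loss
total = content_loss + weighted_adv
adv_share = weighted_adv / total

print(f"weighted adversarial term: {weighted_adv:.4f}")  # 0.0033
print(f"total generator loss:      {total:.4f}")         # 0.0693
print(f"adversarial share:         {adv_share:.1%}")     # 4.8%
```

With these numbers the adversarial term is only about 5% of the generator's total loss, which fits the point that training is dominated by the content loss.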

Shubham3101 commented on September 21, 2024

My content loss fluctuates between 0.060 and 0.065, the generator adversarial loss fluctuates between 0.03 and 0.09, and the discriminator adversarial loss goes as low as 1e-10.

Is this normal, or is something wrong?

I know it's not related to your project; I just need your opinion.

Thanks
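
One way to read these numbers is to convert each binary cross-entropy value back into the discriminator output it implies. A minimal sketch, assuming both values are unweighted per-image BCE on a single sigmoid output and, for illustration, attributing the whole 1e-10 discriminator loss to the generated-image term; `implied_probability` is a hypothetical helper:

```python
import math


def implied_probability(bce_loss, label):
    """Invert single-sample binary cross-entropy for a sigmoid output.

    For label 1 the loss is -log(p), so p = exp(-loss).
    For label 0 the loss is -log(1 - p), so p = 1 - exp(-loss),
    computed with expm1 for precision when the loss is tiny.
    """
    if label == 1:
        return math.exp(-bce_loss)
    return -math.expm1(-bce_loss)


# Generator adversarial loss: generated images scored against the "real" label.
for loss in (0.03, 0.09):
    print(f"gen loss {loss}: D(G(x)) ~ {implied_probability(loss, 1):.3f}")
# gen loss 0.03: D(G(x)) ~ 0.970
# gen loss 0.09: D(G(x)) ~ 0.914

# Discriminator loss on generated images scored against the "fake" label.
print(f"D(G(x)) ~ {implied_probability(1e-10, 0):.1e}")  # D(G(x)) ~ 1.0e-10
```

Under these assumptions the two values imply very different discriminator outputs on generated images (about 0.91 to 0.97 versus about 1e-10), so it may be worth checking exactly which quantity each logged number represents.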
