Comments (9)

xiankgx commented on May 30, 2024

I'm really not sure about that.

from dalle2-pytorch.

lucidrains commented on May 30, 2024

@xiankgx hmm, i don't understand your second paragraph here

the training objective is different depending on whether you are predicting x_start or not https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L1515
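As a rough illustration of how the regression target changes with the parameterization, here is a minimal NumPy sketch. It assumes the standard DDPM forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps and a plain MSE loss. The names `linear_alphas_cumprod`, `q_sample` and `diffusion_loss` and the toy linear beta schedule are hypothetical, not the repository's code.

```python
import numpy as np

def linear_alphas_cumprod(timesteps: int = 1000) -> np.ndarray:
    # Toy linear beta schedule (assumption, for illustration only)
    betas = np.linspace(1e-4, 0.02, timesteps)
    return np.cumprod(1.0 - betas)

def q_sample(x_start, t, noise, alphas_cumprod):
    # Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    a_bar = alphas_cumprod[t]
    return np.sqrt(a_bar) * x_start + np.sqrt(1.0 - a_bar) * noise

def diffusion_loss(model, x_start, t, alphas_cumprod, predict_x_start, rng):
    # Hypothetical training step: the target depends on what the model predicts
    noise = rng.standard_normal(x_start.shape)
    x_t = q_sample(x_start, t, noise, alphas_cumprod)
    pred = model(x_t, t)
    target = x_start if predict_x_start else noise
    return float(np.mean((pred - target) ** 2))
```

The network and sampler stay the same in both cases. Only the quantity being regressed differs, which is why the sampling code later needs to convert a noise prediction back into an x_start estimate.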

lucidrains commented on May 30, 2024

i am also not totally confident in the new objective, as the in-line comments in the code show, so if you find another paper that uses this objective, i would definitely appreciate it

lucidrains commented on May 30, 2024

ok, time to walk my dog 🐕 be back later!

xiankgx commented on May 30, 2024

I understand the model can either:

  • predict x_start directly, or
  • predict the noise, from which x_start can then be computed.
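For reference, the conversion in the second case follows from inverting the standard DDPM forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, which gives x_0 = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps. The NumPy sketch below assumes that parameterization and a toy linear beta schedule. It is a standalone illustration, not the repository's `predict_start_from_noise`.

```python
import numpy as np

def predict_start_from_noise(x_t, t, noise, alphas_cumprod):
    # Invert the forward process for x_0, given a noise estimate:
    # x_0 = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps
    a_bar = alphas_cumprod[t]
    return np.sqrt(1.0 / a_bar) * x_t - np.sqrt(1.0 / a_bar - 1.0) * noise

# Round trip with a toy linear beta schedule (assumption)
alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
rng = np.random.default_rng(0)
x_start = rng.uniform(-1.0, 1.0, size=(2, 4))
noise = rng.standard_normal((2, 4))
t = 700
x_t = np.sqrt(alphas_cumprod[t]) * x_start + np.sqrt(1.0 - alphas_cumprod[t]) * noise
recovered = predict_start_from_noise(x_t, t, noise, alphas_cumprod)
```

With the true noise the round trip is exact. With an estimated noise, any error in eps is multiplied by sqrt(1 / a_bar_t - 1), which is large at late timesteps, so the resulting x_start estimate can easily fall outside [-1, 1].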

However, the code shows that no matter what the model predicts, x_recon ends up as the estimate of x_start.

def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
        pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

        if predict_x_start:
            x_recon = pred
        else:
            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

        if clip_denoised and not predict_x_start:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

x_recon is an estimate of x_start because of these lines:

if predict_x_start:
    x_recon = pred
else:
    x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

Hence, I don't understand why we clamp x_recon only when the model predicts the noise rather than x_start directly, in the following lines:

if clip_denoised and not predict_x_start:
    x_recon.clamp_(-1., 1.)
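Because both branches yield an x_start estimate, a clamp to the data range arguably applies equally to both. Below is a minimal NumPy sketch of that variant. It assumes pixel-space data scaled to [-1, 1] and the standard DDPM posterior mean of q(x_{t-1} | x_t, x_0). The helpers `make_schedule`, `q_posterior_mean` and `p_mean` are hypothetical standalone re-implementations, not the repository's methods.

```python
import numpy as np

def make_schedule(timesteps=1000):
    # Toy linear beta schedule (assumption)
    betas = np.linspace(1e-4, 0.02, timesteps)
    alphas_cumprod = np.cumprod(1.0 - betas)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    return betas, alphas_cumprod, alphas_cumprod_prev

def q_posterior_mean(x_start, x_t, t, betas, alphas_cumprod, alphas_cumprod_prev):
    # DDPM posterior mean of q(x_{t-1} | x_t, x_0)
    coef1 = betas[t] * np.sqrt(alphas_cumprod_prev[t]) / (1.0 - alphas_cumprod[t])
    coef2 = (1.0 - alphas_cumprod_prev[t]) * np.sqrt(1.0 - betas[t]) / (1.0 - alphas_cumprod[t])
    return coef1 * x_start + coef2 * x_t

def p_mean(pred, x_t, t, schedule, predict_x_start, clip_denoised=True):
    betas, alphas_cumprod, alphas_cumprod_prev = schedule
    if predict_x_start:
        x_recon = pred
    else:
        a_bar = alphas_cumprod[t]
        x_recon = np.sqrt(1.0 / a_bar) * x_t - np.sqrt(1.0 / a_bar - 1.0) * pred
    # Clamp whenever requested: x_recon estimates x_start in both branches
    if clip_denoised:
        x_recon = np.clip(x_recon, -1.0, 1.0)
    return q_posterior_mean(x_recon, x_t, t, betas, alphas_cumprod, alphas_cumprod_prev)
```

Dropping the `not predict_x_start` condition is the only change from the quoted snippet. Whether clamping helps when the model predicts x_start directly remains an empirical question.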

lucidrains commented on May 30, 2024

oh I understand! Yes, this makes sense for the decoder, but not for the diffusion prior (although for the prior, do you think we could also clamp with l2norm?)
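For the diffusion prior, x_start is a CLIP image embedding rather than an image, so a [-1, 1] box has no natural meaning. If the embeddings are trained as L2-normalized vectors (an assumption here), one hypothetical analogue of clamping is to project the predicted x_start back onto the sphere of the expected radius. The function `l2norm_clamp` and its `scale` parameter are illustrative, not part of the repository.

```python
import numpy as np

def l2norm_clamp(x_recon: np.ndarray, scale: float = 1.0, eps: float = 1e-12) -> np.ndarray:
    # Project each embedding (last axis) onto the sphere of radius `scale`,
    # a hypothetical embedding-space analogue of clamping pixels to [-1, 1]
    norms = np.linalg.norm(x_recon, axis=-1, keepdims=True)
    return scale * x_recon / np.maximum(norms, eps)
```

Unlike element-wise clamping, which only touches out-of-range values, this rescales every vector and keeps only its direction.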

lucidrains commented on May 30, 2024

Thanks, I'll make the change once I'm back home

lucidrains commented on May 30, 2024

> I'm really not sure about that.

when in doubt, make it a hyperparameter 77fa34e ;)

lucidrains commented on May 30, 2024

@xiankgx so i realized the reason i didn't clip in the Decoder is that i introduced latent diffusion; however, there is an improved VQGAN variant out there that proposes to l2norm the codebook, so if we find that l2norm clamping works, we could also add it to the sampling steps as an extra guardrail
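The l2-normalized codebook idea can be sketched roughly as follows. Both the latent vectors and the codebook entries are normalized, and each latent is snapped to the code with the highest cosine similarity. If a latent decoder's latents are meant to live on that unit sphere, normalizing the predicted x_start per spatial position during sampling would be the corresponding guardrail. The shapes (latents as (C, H, W), codebook as (K, C)) and the functions `quantize_l2` and `clamp_latents_l2` are assumptions for illustration, not VQGAN or DALLE2-pytorch code.

```python
import numpy as np

def l2norm(x, axis, eps=1e-12):
    # Divide by the L2 norm along `axis`, guarding against zero vectors
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

def quantize_l2(latents, codebook):
    # latents: (C, H, W); codebook: (K, C). Both are compared on the unit sphere.
    c, h, w = latents.shape
    flat = l2norm(latents.reshape(c, h * w).T, axis=-1)   # (H*W, C)
    codes = l2norm(codebook, axis=-1)                      # (K, C)
    indices = np.argmax(flat @ codes.T, axis=-1)           # nearest code by cosine similarity
    return codes[indices].T.reshape(c, h, w), indices.reshape(h, w)

def clamp_latents_l2(x_recon):
    # Hypothetical sampling guardrail: unit-normalize each position's channel vector
    return l2norm(x_recon, axis=0)
```

Because every quantized latent is a unit vector, renormalizing the predicted latents keeps sampling on the same manifold the codebook lives on, which is the kind of extra guardrail suggested above.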
