
Comments (11)

DianeBouchacourt commented on August 23, 2024

Has this issue been solved? Training on dSprites, I also get a negative tc_loss.


YannDubs commented on August 23, 2024

Awesome, thanks for checking. A few comments:

1/ What do you mean by "+ve" and "-ve"? What is "ve"?

2/ Looking back at it, it seems that I actually had the correct code and then introduced the problem in a late-night push (#43).

Here's what I had before my changes:

def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=False):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    if not is_mss:
        log_qz, log_prod_qzi = _minibatch_weighted_sampling(latent_dist,
                                                            latent_sample,
                                                            n_data)

    else:
        log_qz, log_prod_qzi = _minibatch_stratified_sampling(latent_dist,
                                                              latent_sample,
                                                              n_data)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx


def _minibatch_weighted_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    weighted sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample: torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References 
    -----------
       [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
       autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) -
                    math.log(batch_size * data_size)).sum(dim=1)
    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False
                             ) - math.log(batch_size * data_size)

    return log_qz, log_prod_qzi


def _minibatch_stratified_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    stratified sampling.
    
    Parameters
    -----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample: torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References 
    -----------
       [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
       autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_iw_mat = log_importance_weight_matrix(batch_size, data_size).to(latent_sample.device)
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) +
                                   mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_qz, log_prod_qzi
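
For completeness, the helpers this code calls (log_density_gaussian, matrix_log_density_gaussian, log_importance_weight_matrix) live elsewhere in the repo. A minimal sketch of what they compute, assuming diagonal Gaussians, with the weight matrix following Ricky Chen's reference implementation; the repo's actual versions may differ in detail:

import math

import torch


def log_density_gaussian(x, mu, logvar):
    """Elementwise log N(x; mu, diag(exp(logvar)))."""
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    return normalization - 0.5 * (x - mu) ** 2 / logvar.exp()


def matrix_log_density_gaussian(x, mu, logvar):
    """All pairwise log q(z_i|x_j); output shape (batch_size, batch_size, dim)."""
    batch_size, dim = x.shape
    x = x.view(batch_size, 1, dim)
    mu = mu.view(1, batch_size, dim)
    logvar = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x, mu, logvar)


def log_importance_weight_matrix(batch_size, dataset_size):
    """Log of the (batch_size, batch_size) importance weights used by MSS."""
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.full((batch_size, batch_size), 1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()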


sisodia-a commented on August 23, 2024

I didn't test that yet. I was just trying to see, from the math/code, where I was getting the error.


sisodia-a commented on August 23, 2024

Using some random matrices (code attached: temp.txt), I ran your code as well as Ricky Chen's code to compare what is happening.

I found:

MWS:
log_qz != logqz_ricky
log_prod_qzi != logqz_prodmarginals_ricky

MSS:
logqz_prodmarginals_ricky_mss == log_prod_qzi_mss
logqz_ricky_mss != log_qz_mss

So, when I use your code with is_mss=True, I get a negative tc_loss, and with is_mss=False, I get a negative mi_loss and a negative tc_loss.
I ran it on the dSprites dataset with batch size 128.
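
For reference, a minimal stand-in for the attached temp.txt (hypothetical, not the actual attachment) feeds one random posterior to the repo's MWS estimator and to Ricky Chen's MWS formulas, then compares the outputs:

import math

import torch

torch.manual_seed(0)
batch_size, latent_dim, dataset_size = 128, 10, 737280  # dSprites has 737,280 images
latent_sample = torch.randn(batch_size, latent_dim)
latent_dist = (torch.randn(batch_size, latent_dim),        # means
               torch.randn(batch_size, latent_dim) * 0.1)  # log variances

# this repo's MWS estimates
log_qz, log_prod_qzi = _minibatch_weighted_sampling(latent_dist, latent_sample, dataset_size)

# Ricky Chen's MWS formulas applied to the same inputs
_logqz = matrix_log_density_gaussian(latent_sample, *latent_dist)
logqz_ricky = torch.logsumexp(_logqz.sum(2), dim=1) - math.log(batch_size * dataset_size)
logqz_prodmarginals_ricky = (torch.logsumexp(_logqz, dim=1)
                             - math.log(batch_size * dataset_size)).sum(1)

print(torch.allclose(log_qz, logqz_ricky),
      torch.allclose(log_prod_qzi, logqz_prodmarginals_ricky))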

Then I changed the _get_log_pz_qz_prodzi_qzCx function in your code to make it similar to Ricky Chen's code.

batch_size, hidden_dim = latent_sample.shape

# calculate log q(z|x)
log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

# calculate log p(z)
# mean and log var is 0
zeros = torch.zeros_like(latent_sample)
log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)-math.log(batch_size * n_data))       ## Ankit - modified
log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False)-math.log(batch_size * n_data)).sum(1) ## Ankit - modified

# the estimates above are the is_mss=False (MWS) path
if is_mss:                                                                                                                ## Ankit - modified
    # use stratification                                                                                                  ## Ankit - modified
    log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)                                ## Ankit - modified
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)                                        ## Ankit - modified
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size,batch_size,1)+mat_log_qz, dim=1, keepdim=False).sum(1)      ## Ankit - modified

return log_pz, log_qz, log_prod_qzi, log_q_zCx

Then I get positive losses for everything when is_mss=True, but with is_mss=False I get a negative dw_kl_loss term.
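
To connect these sign observations to the loss terms: the beta-TCVAE objective splits the KL into three pieces, each computed from the four log terms the function returns. Schematically (names as used in this thread, not a quote from any particular file):

# minibatch estimates of the three pieces of KL(q(z|x) || p(z));
# the sum telescopes to (log_q_zCx - log_pz).mean(), but the
# individual pieces depend on the biased log_qz / log_prod_qzi
# estimates, which is why they can come out negative
mi_loss = (log_q_zCx - log_qz).mean()        # mutual information I[x; z]
tc_loss = (log_qz - log_prod_qzi).mean()     # total correlation TC(z)
dw_kl_loss = (log_prod_qzi - log_pz).mean()  # dimension-wise KL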


sisodia-a commented on August 23, 2024

https://github.com/rtqichen/beta-tcvae/ calculates, in the minibatch weighted sampling case:

logqz_prodmarginals = (logsumexp(_logqz, dim=1, keepdim=False) - math.log(batch_size * dataset_size)).sum(1)
logqz = (logsumexp(_logqz.sum(2), dim=1, keepdim=False) - math.log(batch_size * dataset_size))

and in the minibatch stratified sampling case:

logiw_matrix = Variable(self._log_importance_weight_matrix(batch_size, dataset_size).type_as(_logqz.data))
logqz = logsumexp(logiw_matrix + _logqz.sum(2), dim=1, keepdim=False)
logqz_prodmarginals = logsumexp(logiw_matrix.view(batch_size, batch_size, 1) + _logqz, dim=1, keepdim=False).sum(1)

So in this codebase, shouldn't we also do (when is_mss is False):

log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False) - math.log(batch_size * n_data))
log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) - math.log(batch_size * n_data)).sum(1)

and (when is_mss is True):

log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) + mat_log_qz, dim=1, keepdim=False).sum(1)
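
For reference, both snippets implement the minibatch weighted sampling estimator from Chen et al. (2018); with M the batch size and N the dataset size, it reads (as I read the paper):

\mathbb{E}_{q(z)}\left[\log q(z)\right] \approx \frac{1}{M}\sum_{i=1}^{M}\left[\log\sum_{j=1}^{M} q\left(z(x_i)\mid x_j\right) - \log(MN)\right]

which is where the math.log(batch_size * dataset_size) correction comes from. In the stratified version, that normalization is instead absorbed into the importance weights (logiw_matrix), so no separate log(MN) term appears.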


YannDubs commented on August 23, 2024

Thanks @UserName-AnkitSisodia!
I think you might be right (I am taking a sum instead of marginalizing in the log space), but it's been a long time, so I'll have to double-check this weekend.

Did you test it with these changes?


YannDubs commented on August 23, 2024

Which is (I believe) exactly what you tested.

  • Does it also work for is_mss=False?

  • Just to be sure I understand, are you saying that with MSS this makes dw_kl_loss become negative?

  • Did you see any impact on the qualitative samples when training a model that way?


sisodia-a commented on August 23, 2024

Yes, this makes the code exactly the same. Once these changes are made, I get a negative dw_kl_loss term in the case of _minibatch_weighted_sampling. For _minibatch_stratified_sampling, all loss terms are positive. I tested on dSprites.


YannDubs commented on August 23, 2024

And qualitatively, do you see any differences?


shi-yu-wang commented on August 23, 2024

I also got a negative loss with the dSprites data.


shi-yu-wang commented on August 23, 2024

The tc_loss, specifically.
