
wae-rnf-lm's Introduction

Riemannian Normalizing Flow on WAE

Code for our NAACL 2019 paper "Riemannian Normalizing Flow on Variational Wasserstein Autoencoder for Text Modeling": https://arxiv.org/abs/1904.02399

Authors: Prince Zizhuang Wang and William Yang Wang

An example where the latent space does not reflect the input space. Left: a manifold that is highly curved in the central region; the yellow line is the geodesic (shortest) path connecting two sample points on the manifold. Right: the projection of the manifold into a 2D latent space, where brightness indicates curvature with respect to the manifold; the green line is the geodesic path when curvature is taken into account, while the blue line is the geodesic path if the latent space is treated as Euclidean. Middle: the two geodesic paths projected from the latent space back onto the manifold; the white line corresponds to the straight Euclidean path, and it is far longer than the true geodesic because it ignores the curvature of the latent space.
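
To make the Euclidean-versus-curved comparison concrete, here is a minimal sketch (with a hypothetical metric function, not the one induced by the decoder) of how the same straight path acquires different lengths under different Riemannian metrics:

import torch

def path_length(points, metric_fn):
    # sum segment lengths ds = sqrt(dz^T G dz), with G evaluated at midpoints
    total = 0.0
    for a, b in zip(points[:-1], points[1:]):
        dz = (b - a).unsqueeze(1)              # (d, 1) segment
        G = metric_fn((a + b) / 2)             # (d, d) metric tensor
        total += torch.sqrt(dz.T @ G @ dz).item()
    return total

# straight line between two points in a 2D latent space
ts = torch.linspace(0, 1, 50).unsqueeze(1)
line = (1 - ts) * torch.tensor([-1., 0.]) + ts * torch.tensor([1., 0.])

euclidean = lambda z: torch.eye(2)
# hypothetical metric inflated near the origin, mimicking the curved region
curved = lambda z: (1 + 10 * torch.exp(-(z ** 2).sum())) * torch.eye(2)

print(path_length(line, euclidean))  # straight line: length ~ 2.0
print(path_length(line, curved))     # same line, far longer under curvature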

Running the code

Requirements

  • python 3.6
  • pytorch 1.0.0
  • spacy 2.0.12
  • torchtext 0.3.0

Training

train on ptb

$ python main.py --dist normal --kla --mmd --kernel im --flow --n_flows 3 --center --reg im --t 0.8 --mmd_w 10 --data data/ptb

train on yahoo

$ python main.py --dist normal --embed_dim 512 --hidden_dim 1024 --kla --center --flow --mmd --t 0.8 --mmd_w 10 --reg im --data data/yahoo

train on yelp

$ python main.py --dist normal --kla --center --flow --mmd --t 0.8 --mmd_w 10 --reg im --data data/yelp

Options

Option     Usage                           Value (Range)
---------  ------------------------------  -----------------
kla        use KL annealing                True or False
center     use clusters to compute MMD     True or False
flow       use normalizing flow            True or False
mmd        use Wasserstein distance        True or False
enc_type   encoder model                   lstm or gru
de_type    decoder model                   lstm or gru
t          KL divergence weight            default = 0.8
mmd_w      MMD weight                      default = 10
dist       choice of prior and posterior   normal or vmf
kernel     choice of kernel for MMD        g or im
reg        choice of kernel for RNF        g or im

Acknowledgement

wae-rnf-lm's People

Contributors

kingofspace0wzz


wae-rnf-lm's Issues

Question about KL Computation

Hi,

Thanks a lot for sharing this work!

May I know why you chose to compute the KL divergence as the difference between two entropies at this line? That does not seem to align with the definition of KL divergence. Have I misunderstood something here?

Thanks
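
For context, the usual single-sample Monte Carlo estimate of KL(q || p) is E_q[log q(z) - log p(z)], which can be written as a difference of two expected log-density (entropy-like) terms. A minimal generic sketch, not the repository's exact code:

import torch
from torch.distributions import Normal

# Monte Carlo KL: draw z ~ q, then take log q(z) - log p(z).
# The two sums below are the "two entropies" the question refers to:
# an (approximate) negative entropy of q and a cross-entropy against p.
q = Normal(torch.zeros(8), 0.5 * torch.ones(8))   # hypothetical posterior
p = Normal(torch.zeros(8), torch.ones(8))         # standard normal prior

z = q.rsample()
kl_mc = (q.log_prob(z) - p.log_prob(z)).sum()     # single-sample estimate
kl_exact = torch.distributions.kl_divergence(q, p).sum()
print(kl_mc.item(), kl_exact.item())              # estimate vs. closed form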

Intuition behind pretraining to find centers

So in the paper it is mentioned that a standard KL annealed VAE is trained to get an initial approximation of the manifold.

Is the initial VAE supposed to have the same encoder-decoder architecture as the RNF-WAE? In any case, since the initializations, parameters, etc. could be different, what guarantees that the initial centers learned are the actual centers that occur during training of the RNF-WAE?
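
A plausible recipe, assuming (this is not spelled out in the paper) that the centers come from clustering the pretrained VAE's posterior means; vae.encode and train_loader are hypothetical names:

import torch
from sklearn.cluster import KMeans

@torch.no_grad()
def compute_centers(vae, train_loader, n_centers=20):
    # encode the corpus with the pretrained, KL-annealed VAE
    mus = [vae.encode(batch)[0] for batch in train_loader]  # posterior means
    codes = torch.cat(mus).cpu().numpy()
    # cluster the codes; the cluster centers approximate the manifold
    km = KMeans(n_clusters=n_centers, n_init=10).fit(codes)
    return torch.tensor(km.cluster_centers_, dtype=torch.float32)

Note this gives only an initial approximation: the two models need not share weights for the recipe to run, and whether those centers remain representative as the RNF-WAE trains is exactly the open question above.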

How to work on new data?

Very nice work, and thanks for sharing, but I have a few questions.
In your paper, in order to maximize the regularized Jacobian, we need to compute the centers of the posterior codes z. In your code the centers were pretrained, and we just load them at the beginning.
Since I want to work on some new data, I modified the forward function of NormalizingFlows like this:

def forward(self, z, centers=None):
    log_det_jacobian = []
    penalty = []
    if centers is None:
        # no pretrained centers: estimate 20 centers from the current
        # batch with k-means (Lloyd's algorithm) on every forward pass
        with torch.no_grad():
            _, centers = lloyd(z, 20)

    for flow in self.flows:
        z, j, d = flow(z, centers, reg=self.reg, band=self.band)
        log_det_jacobian.append(j)
        penalty.append(d)

    return z, sum(log_det_jacobian), sum(penalty)

but the training process becomes much slower than before. I'm wondering whether you have any ideas for working on new data. Do I need to first train a VAE to compute the centers of the posterior codes z and then load them at the beginning of training the WAE-RNF?
Thanks for your help.
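
One likely cause of the slowdown is that lloyd (k-means) runs on every forward pass in the modified code. A sketch of computing the centers once up front and reusing them, with encoder and train_loader as assumed names:

import torch

@torch.no_grad()
def precompute_centers(encoder, train_loader, k=20):
    # one pass over the data, then a single k-means run
    codes = torch.cat([encoder(batch) for batch in train_loader])
    _, centers = lloyd(codes, k)
    return centers

# centers = precompute_centers(encoder, train_loader)
# Pass `centers` into forward() on every step so the `centers is None`
# branch, and its per-batch k-means, never executes.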

Generating sentences from curved space

Hi, I have a query about the generation process. In a standard VAE, to generate random samples you sample from the standard normal prior and feed the sample to the decoder for greedy decoding.

In this paper, should you pass the prior sample through the flow to get the latent code, or feed the standard normal sample directly to the decoder?
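
To make the two options concrete, a minimal sketch with hypothetical flow and decoder interfaces (following the forward() shown in the previous issue):

import torch

latent_dim = 32                        # assumed latent size
z0 = torch.randn(1, latent_dim)        # sample from the standard normal prior

# Option A: push the prior sample through the flow, then decode.
z_k, _, _ = flow(z0)                   # flow returns (z, log_det, penalty)
sentence_a = decoder.greedy(z_k)       # hypothetical greedy decoder

# Option B: feed the prior sample to the decoder directly.
sentence_b = decoder.greedy(z0)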

Query about the KL divergence calculation with RNF

Hi, I have been reading the paper and looking at the code, and I don't understand how the KL is calculated in the loss term.

The loss calculation has a KL term with flow:
(q_z.log_prob(z0).sum() - p_z.log_prob(z).sum())

Can you explain why z0 is used for the posterior and z is used for the prior? Also, the flow_kld function seems to use z for both; how is that different?
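
For comparison, the textbook change-of-variables KL for a flow posterior evaluates the posterior density at the pre-flow sample z0 and the prior at the post-flow sample, plus a log-determinant correction. A generic sketch, not necessarily what flow_kld does:

import torch

# KL ≈ E[ log q(z0) - log|det J| - log p(zK) ]
# q_z is the base posterior, z0 its sample, zK the flowed sample, and
# log_det_jacobian accumulates the flow's volume change.
def flow_kl(q_z, p_z, z0, zK, log_det_jacobian):
    return (q_z.log_prob(z0).sum()
            - log_det_jacobian.sum()
            - p_z.log_prob(zK).sum())

Read this way, evaluating the posterior at z0 and the prior at the transformed z is what the change of variables requires; how that squares with flow_kld using z for both is the question.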

Questions about your code

Hi @kingofspace0wzz,

Thanks for your code. It seems that in model.py there are no definitions for self.z2h and self.lookup, even though both are used in generate(self, z, max_length, sos_id).
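
For readers hitting the same missing attributes, here are plausible stand-in definitions, guessed only from how generate() appears to use them; these are assumptions, not the authors' code:

import torch.nn as nn

# Hypothetical patch: guesses at the missing layers in model.py,
# based solely on the generate(self, z, max_length, sos_id) signature.
class PatchedModel(nn.Module):
    def __init__(self, vocab_size=10000, embed_dim=512,
                 hidden_dim=1024, latent_dim=32):
        super().__init__()
        # z2h: project the latent code z to the decoder's initial hidden state
        self.z2h = nn.Linear(latent_dim, hidden_dim)
        # lookup: embed token ids (e.g. sos_id) before feeding the decoder
        self.lookup = nn.Embedding(vocab_size, embed_dim)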
