h3lio5 / linguistic-style-transfer-pytorch

Implementation of "Disentangled Representation Learning for Non-Parallel Text Style Transfer" (ACL 2019) in PyTorch

Python 100.00%
style-transfer pytorch latent-representations text-style-transfer variational-autoencoder adversarial-learning disentangled-representations disentanglement acl natural-language-generation

linguistic-style-transfer-pytorch's Introduction

Linguistic Style Transfer

Implementation of the paper "Disentangled Representation Learning for Non-Parallel Text Style Transfer" in PyTorch.

Abstract

This paper tackles the problem of disentangling the latent representations of style and content in language models. We propose a simple yet effective approach, which incorporates auxiliary multi-task and adversarial objectives, for style prediction and bag-of-words prediction, respectively. We show, both qualitatively and quantitatively, that the style and content are indeed disentangled in the latent space. This disentangled latent representation learning can be applied to style transfer on non-parallel corpora. We achieve high performance in terms of transfer accuracy, content preservation, and language fluency, in comparison to various previous approaches.

To get a basic overview of the paper, read the summary.

1. Setup Instructions and Dependencies

You may set up the repository on your local machine by either downloading it or running the following line in a terminal.

git clone https://github.com/h3lio5/linguistic-style-transfer-pytorch.git

All dependencies required by this repo can be installed by creating a virtual environment with Python 3.7 and running

python3 -m venv .env
source .env/bin/activate
pip install -r requirements.txt
pip install -e .

Note: Run all the commands from the root directory.

2. Training a Model from Scratch

To train your own model from scratch, run

python train.py 
  • The parameters for your experiment are all set by default, but you are free to change them by editing the config.py file.
  • The training script will create a folder checkpoints as specified in your config.py file.
  • This folder will contain all model parameters saved after each epoch.

3. Transferring Text Style from Trained Models

To transfer text style of a sentence from trained models, run

python generate.py 

Running the above command prompts you for the source sentence and the target style.
Example:

Please enter the source sentence: the book is good
Please enter the target style: pos or neg: neg
Your style transfered sentence is: the book is boring

4. Repository Overview

This repository contains the following files and folders

  1. images: Contains media for readme.md.

  2. linguistic_style_transfer_pytorch/data_loader.py: Contains helper functions that load data.

  3. linguistic_style_transfer_pytorch/model.py: Contains code to build the model.

  4. linguistic_style_transfer_pytorch/config.py: Contains information about various file paths and model configuration.

  5. linguistic_style_transfer_pytorch/utils/vocab.py: Contains code to build the vocabulary and word embeddings.

  6. linguistic_style_transfer_pytorch/utils/preprocess.py: Contains code to preprocess the data.

  7. linguistic_style_transfer_pytorch/utils/train_w2v.py: Contains code to train word2vec embeddings from scratch on the downloaded data.

  8. generate.py: Used to transfer the style of a source sentence using a trained model.

  9. train.py: Contains code to train models from scratch.

  10. requirements.txt: Lists dependencies for easy setup in virtual environments.

5. Training and Inference

Illustration of training and inference.
[Figure: training_and_inference]

Resources

  • Original paper: Disentangled Representation Learning for Non-Parallel Text Style Transfer
  • TensorFlow implementation by the authors

linguistic-style-transfer-pytorch's Issues

file not found

Hi @h3lio5 ,
I get this error:
FileNotFoundError: [Errno 2] No such file or directory: '/content/linguistic-style-transfer-pytorch/linguistic_style_transfer_pytorch/data/word_embeddings.npy'

I looked for the file but it is not there. Where can I find it?

Best!
P.S. Great work by the way!

Issue in preprocess.py

In preprocess.py, shouldn't line 57 be
labels.file.write("neg" + "\n")
Otherwise negative labels will never be generated, right? @h3lio5

the final hidden state's representation is wrong

@h3lio5 Sorry to bother you with an issue.
I have found a small but possibly serious bug
at model.py line 95:

sentence_emb = output[torch.arange(output.size(0)), seq_lengths-1]
# get content and style embeddings from the sentence embeddings,i.e. final_hidden_state

For a Bi-GRU, the final hidden state h_n contains both the forward-pass and the backward-pass states. The backward-pass state is actually equal to the backward half of output[0], and the forward-pass state is equal to the forward half of output[-1] (see https://discuss.pytorch.org/t/rnn-output-vs-hidden-state-dont-match-up-my-misunderstanding/43280).
So output[-1] alone (i.e. sentence_emb) cannot represent the final hidden state.
This also causes the generate.py issue that has already been reported.
I will try to fix the two issues~ @michaeldu1 @arijit1410 @Doragd 😄
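
The indexing point made in this issue can be checked without PyTorch. The following is a minimal sketch using a toy scalar recurrence in plain Python, not the repository's GRU; `run_direction`, `bidirectional`, and the weights `w` and `u` are hypothetical names chosen for illustration. It shows that the per-step output at the last position holds the forward final state but only a one-step backward state, while the backward final state sits at position 0.

```python
import math

def run_direction(xs, w=0.5, u=0.9):
    """Toy scalar recurrence h_t = tanh(w * x_t + u * h_{t-1}); returns every h_t."""
    h, states = 0.0, []
    for x in xs:
        h = math.tanh(w * x + u * h)
        states.append(h)
    return states

def bidirectional(xs):
    """Return per-step (forward, backward) outputs plus the two final states."""
    fwd = run_direction(xs)
    # The backward direction reads the sequence right to left; its states are
    # reversed again so that index t lines up with input position t.
    bwd = run_direction(xs[::-1])[::-1]
    output = list(zip(fwd, bwd))
    final_fwd = fwd[-1]  # the forward direction finishes at the last position
    final_bwd = bwd[0]   # the backward direction finishes at the first position
    return output, (final_fwd, final_bwd)

xs = [0.3, -1.2, 0.8, 2.0]  # hypothetical input sequence
output, (final_fwd, final_bwd) = bidirectional(xs)

last_step = output[len(xs) - 1]
assert last_step[0] == final_fwd   # forward half matches the final forward state
assert last_step[1] != final_bwd   # backward half has only seen one token
assert output[0][1] == final_bwd   # the final backward state lives at position 0

# A sentence embedding that matches the final hidden state takes one half from each end.
sentence_emb = (output[len(xs) - 1][0], output[0][1])
```

In `torch.nn.GRU` with `bidirectional=True` the same layout applies: the first `hidden_size` features of each output step come from the forward direction and the rest from the backward direction, so an embedding taken only from the step at `seq_lengths-1` would miss the backward final state.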

metrics

Hi, thanks for your great work! It really helped me understand the paper much better. I was able to get the model running but I was wondering if your code kept track of any automatic evaluation metrics as it was training. Thank you!

Do not use this code. Use this instead!

This is for all people who are trying to replicate this paper in pytorch. This repo is hopelessly broken, even after fixing all bugs there are still severe implementation oversights and training is impossible. Kudos to the author for starting this initiative but please do not waste your time here.

We are attempting to build another implementation of this paper in pytorch and it is nearing completion. https://github.com/sharan21/disentangled-style-transfer-vae

Please feel free to improve this repo and contribute. Thanks to @h3lio5 for still providing support and functions to work with. :)

Shape error in generate.py

Hi, thanks for your great work!! I was trying to reproduce your results and was able to train the model successfully. However, when trying to run generate.py, I had a shape error:

`Enter the source sentence: the book is good
Enter the target style: pos or neg: neg
sentence shape is torch.Size([2, 1, 256])
Traceback (most recent call last):
File "generate.py", line 42, in
target_tokenids = model.transfer_style(token_ids, target_style_id)

File "/data/home/surviv/linguistic-style-transfer-pytorch/linguistic_style_transfer_pytorch/model.py", line 511, in transfer_style
final_hidden_state)

File "/data/home/surviv/linguistic-style-transfer-pytorch/linguistic_style_transfer_pytorch/model.py", line 227, in get_content_emb
mu = self.content_mu(sentence_emb)

File "/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 547, in call
result = self.forward(*input, **kwargs)

File "/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)

File "/data/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/functional.py", line 1371, in linear
output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [2 x 256], m2: [512 x 128] at /opt/conda/conda-bld/pytorch_1565272269120/work/aten/src/TH/generic/THTensorMath.cpp:752`

It seems like the error is in model.py in line 227. Do you know why this shape error exists? Thanks so much for your work and help, I look forward to your response!!
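
One possible reading of the traceback, offered as an assumption rather than a confirmed diagnosis: the printed shape `[2, 1, 256]` looks like a bidirectional final hidden state laid out as (directions, batch, hidden), while the linear layer expects 512 input features per example. A minimal plain-Python sketch of the reshaping that would reconcile the two, with nested lists standing in for tensors and `concat_directions` as a hypothetical helper:

```python
def concat_directions(h_n):
    """h_n is nested lists shaped [2][batch][hidden]; returns [batch][2 * hidden]."""
    forward, backward = h_n
    # Join the forward and backward vectors of each batch element end to end.
    return [f + b for f, b in zip(forward, backward)]

hidden, batch = 256, 1  # sizes taken from the traceback above
h_n = [
    [[0.0] * hidden for _ in range(batch)],  # stand-in for the forward final state
    [[1.0] * hidden for _ in range(batch)],  # stand-in for the backward final state
]

sentence_emb = concat_directions(h_n)
shape = (len(sentence_emb), len(sentence_emb[0]))  # one row of 512 features
```

With tensors, `torch.cat((h_n[0], h_n[1]), dim=1)` would perform the same concatenation for a single-layer bidirectional GRU.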

why kl loss is not used?

Hello, thanks for your code for this paper. It helps me a lot.
But in your code at
linguistic_style_transfer_pytorch/model.py line 157: vae_and_classifier_loss

 vae_and_classifier_loss = mconfig.content_adversary_loss_weight * content_entropy_loss + \
            mconfig.style_adversary_loss_weight * style_entropy_loss + \
            mconfig.style_multitask_loss_weight * style_mul_loss + \
            mconfig.content_multitask_loss_weight * content_mul_loss + \
            reconstruction_loss

Why is the KL loss not used?

by the way,
at train.py line 19

model = AdversarialVAE(inference=False, weight=weights)

The parameter inference does not seem to be defined?
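
For reference on the KL question above: in a VAE with a diagonal Gaussian posterior and a standard normal prior, the KL term has a closed form, and it is normally added to the reconstruction loss with a weight. A minimal plain-Python sketch follows; `kl_to_standard_normal`, `style_kl_weight`, and `content_kl_weight` are hypothetical names, not identifiers from this repository, and all numeric values are placeholders.

```python
import math

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ) for one latent vector."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))

# Hypothetical posterior parameters for the style and content latent spaces.
style_kl = kl_to_standard_normal(mu=[1.0], log_var=[0.0])
content_kl = kl_to_standard_normal(mu=[0.0, 0.0], log_var=[0.0, 0.0])

# Placeholder weights and loss value (assumptions, not values from config.py).
style_kl_weight, content_kl_weight = 0.03, 0.03
reconstruction_loss = 2.0

# The weighted KL terms are summed with the other loss terms.
loss_with_kl = (reconstruction_loss
                + style_kl_weight * style_kl
                + content_kl_weight * content_kl)
```

A posterior that already equals the prior contributes zero, which is why `content_kl` vanishes here while `style_kl` does not.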

Please do not use this code

For anyone who wishes to use this code: please don't. Study the original implementation by the authors instead.

A mismatch between the implementation and the paper

In the implementation of the decoder, here in linguistic-style-transfer-pytorch/linguistic_style_transfer_pytorch/model.py, in this function "def generate_sentences(self, input_sentences, latent_emb, inference=False):"
starting from Line 428, it seems that you are concatenating the content & style embedding to each word embedding of the sentence.
However, based on my understanding of the paper, the concatenation is between the content representation and the style representation, and this forms the hidden state of the decoder (last paragraph of 3.1 in the paper). In contrast, the hidden states of the decoder in this implementation are simply zeros, as in Lines 442 and 463 (in both training and generating mode), which is mentioned by #5.

        input_sentences = torch.cat(
            (sos_token_tensor, input_sentences), dim=1)
        sentence_embs = self.dropout(self.embedding(input_sentences))
        # Make the latent embedding compatible for concatenation
        # by repeating it max_seq_len + 1 times (one extra because <sos> tokens were added)
        latent_emb = latent_emb.unsqueeze(1).repeat(
            1, mconfig.max_seq_len+1, 1)
        gen_sent_embs = torch.cat(
            (sentence_embs, latent_emb), dim=2)

Another problem is the shape error when generating sentences in inference mode: there, latent_emb isn't concatenated to word_emb as it is in training mode. However, I am beginning to wonder whether it is right to concatenate word embeddings and latent embeddings in the first place.

Since I didn't look into the code fully & carefully, I might be getting this wrong.
