MaskGAN: Better Text Generation via Filling in the ______

Code for "MaskGAN: Better Text Generation via Filling in the ______", published at ICLR 2018.

Requirements

  • TensorFlow >= v1.5

Instructions

Warning: The open-source version of this code is still in the process of being tested. Pretraining may not work correctly.

For training on PTB:

  1. Pretrain a language model (LM) on PTB and store the checkpoint in /tmp/pretrain-lm/. Instructions for this step are a work in progress.

  2. Run MaskGAN in MLE pretraining mode. If step 1 was not run, leave language_model_ckpt_dir empty.

python train_mask_gan.py \
 --data_dir='/tmp/ptb' \
 --batch_size=20 \
 --sequence_length=20 \
 --base_directory='/tmp/maskGAN' \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.00074876,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=1,gen_learning_rate_decay=0.95" \
 --mode='TRAIN' \
 --max_steps=100000 \
 --language_model_ckpt_dir=/tmp/pretrain-lm/ \
 --generator_model='seq2seq_vd' \
 --discriminator_model='rnn_zaremba' \
 --is_present_rate=0.5 \
 --summaries_every=10 \
 --print_every=250 \
 --max_num_to_print=3 \
 --gen_training_strategy=cross_entropy \
 --seq2seq_share_embedding
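
The --hparams flag in the command above packs several hyperparameters into one comma-separated string of name=value pairs. As a minimal sketch for inspecting such a string before launching a run, the helper below splits it into a dictionary. The function name parse_hparams is hypothetical and is not part of this repository, whose own parsing may differ in type handling and validation.

```python
def parse_hparams(spec):
    """Parse 'name=value,name=value' into a dict (hypothetical helper)."""
    parsed = {}
    for item in spec.split(","):
        name, _, raw = item.partition("=")
        name, raw = name.strip(), raw.strip()
        if not name or not raw:
            raise ValueError(f"malformed hparam entry: {item!r}")
        # Try int first, then float, and fall back to the raw string.
        try:
            value = int(raw)
        except ValueError:
            try:
                value = float(raw)
            except ValueError:
                value = raw
        parsed[name] = value
    return parsed


# The hparams string from the MLE pretraining command.
spec = (
    "gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,"
    "gen_learning_rate=0.00074876,dis_learning_rate=5e-4,baseline_decay=0.99,"
    "dis_train_iterations=1,gen_learning_rate_decay=0.95"
)
hparams = parse_hparams(spec)
for name, value in hparams.items():
    print(f"{name} = {value!r}")
```

Printing the parsed values makes it easier to spot a mistyped name or a misplaced decimal point in a long hparams string.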

  3. Run MaskGAN in GAN mode. If step 2 was not run, set maskgan_ckpt to empty.

python train_mask_gan.py \
 --data_dir='/tmp/ptb' \
 --batch_size=128 \
 --sequence_length=20 \
 --base_directory='/tmp/maskGAN' \
 --mask_strategy=contiguous \
 --maskgan_ckpt='/tmp/maskGAN' \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.000038877,gen_learning_rate_decay=1.0,gen_full_learning_rate_steps=2000000,gen_vd_keep_prob=0.33971,rl_discount_rate=0.89072,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=2,dis_pretrain_learning_rate=0.005,critic_learning_rate=5.1761e-7,dis_vd_keep_prob=0.71940" \
 --mode='TRAIN' \
 --max_steps=100000 \
 --generator_model='seq2seq_vd' \
 --discriminator_model='seq2seq_vd' \
 --is_present_rate=0.5 \
 --summaries_every=250 \
 --print_every=250 \
 --max_num_to_print=3 \
 --gen_training_strategy='reinforce' \
 --seq2seq_share_embedding=true \
 --baseline_method=critic \
 --attention_option=luong

  4. Generate samples:

python generate_samples.py \
 --data_dir /tmp/ptb/ \
 --data_set=ptb \
 --batch_size=256 \
 --sequence_length=20 \
 --base_directory /tmp/imdbsample/ \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,gen_vd_keep_prob=0.33971" \
 --generator_model=seq2seq_vd \
 --discriminator_model=seq2seq_vd \
 --is_present_rate=0.0 \
 --maskgan_ckpt=/tmp/maskGAN \
 --seq2seq_share_embedding=True \
 --dis_share_embedding=True \
 --attention_option=luong \
 --mask_strategy=contiguous \
 --baseline_method=critic \
 --number_epochs=4
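
The commands above combine --is_present_rate with --mask_strategy=contiguous. The sketch below illustrates one plausible reading of those two flags: is_present_rate is taken as the fraction of tokens left visible, and the masked tokens form a single contiguous span. This reading is an assumption. The function contiguous_mask is hypothetical, and the repository's own mask generation may round or sample differently.

```python
import random


def contiguous_mask(sequence_length, is_present_rate, rng=random):
    """Return a list of booleans: True = token present, False = masked.

    Hypothetical illustration only, not the repository's implementation.
    """
    if not 0.0 <= is_present_rate <= 1.0:
        raise ValueError("is_present_rate must be in [0, 1]")
    # Number of tokens to blank out, assuming the rate is the visible fraction.
    num_masked = int(round((1.0 - is_present_rate) * sequence_length))
    # Pick a start position so the masked span fits inside the sequence.
    start = rng.randint(0, sequence_length - num_masked)
    return [not (start <= i < start + num_masked) for i in range(sequence_length)]


rng = random.Random(0)
mask = contiguous_mask(sequence_length=20, is_present_rate=0.5, rng=rng)
# 'x' marks a visible token, '_' marks a blank for the generator to fill in.
print("".join("x" if present else "_" for present in mask))
```

Under this reading, the is_present_rate=0.0 setting in the sampling command masks every position, so the generator produces the whole sequence.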

Contact for Issues
