
transformer-xl's Introduction

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

This repository contains the code in both PyTorch and TensorFlow for our paper

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution)

Preprint 2018

TensorFlow

  • The source code is in the tf/ folder, supporting (1) single-node multi-gpu training, and (2) multi-host TPU training.
  • Besides the source code, we also provide pretrained TensorFlow models with the state-of-the-art (SoTA) performance reported in the paper.
  • Please refer to tf/README.md for details.

PyTorch

  • The source code is in the pytorch/ folder, supporting single-node multi-gpu training via the module nn.DataParallel (a generic usage sketch follows below).
  • Please refer to pytorch/README.md for details.
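
A rough illustration of this setup (a generic PyTorch sketch with a placeholder model and batch layout, not the repository's actual training loop):

import torch
import torch.nn as nn

# Placeholder model standing in for the repository's language model.
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 1000))

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()   # each forward splits the batch across GPUs
elif torch.cuda.is_available():
    model = model.cuda()

tokens = torch.randint(0, 1000, (8, 32))    # [batch, seq_len] dummy batch
if torch.cuda.is_available():
    tokens = tokens.cuda()
logits = model(tokens)                      # replicas each process a slice of dim 0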

Results

Transformer-XL achieves new state-of-the-art results on multiple language modeling benchmarks. Transformer-XL is also the first to break through the 1.0 barrier on char-level language modeling. Below is a summary.

Method            enwiki8   text8   One Billion Word   WT-103   PTB (w/o finetuning)
Previous Best     1.06      1.13    23.7               20.5     55.5
Transformer-XL    0.99      1.08    21.8               18.3     54.5

Acknowledgement

A large portion of the getdata.sh script comes from the awd-lstm repo. Happy Language Modeling :)

transformer-xl's People

Contributors

cbockman, cclauss, ijkilchenko, kimiyoung, lopuhin, stefan-it, yongbowin, zihangdai

transformer-xl's Issues

tf code question?

Good work! I have two questions about your TF code.

The first:
In the paper, the query vector is computed from the previous layer's hidden state alone, rather than from the concatenation of the previous layer's memory and hidden state. However, in the TF code, I found that the query vector is computed in the same way as the key and value vectors.

[two code screenshots omitted]

The second:
Each layer has a memory tensor of shape [mem_len, batch_size, d_model]. When computing the query, key and value vectors, the input to tf.layers.dense is the concatenation of the current layer's memory and the previous layer's output, which seems to conflict with the paper. Also, why is the gradient stopped in the _cache_mem method rather than in rel_multihead_attn? The latter seems to make more sense.
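
For intuition: because the q/k/v projections are applied position-wise, projecting the concatenated [memory; hidden] input and then keeping only the last qlen query positions gives exactly the same result as projecting the current hidden states alone. A small self-contained sketch of this equivalence (generic PyTorch, not the repository's code):

import torch
import torch.nn as nn

torch.manual_seed(0)
mlen, qlen, bsz, d_model = 4, 3, 2, 8
mem = torch.randn(mlen, bsz, d_model)
hidden = torch.randn(qlen, bsz, d_model)
proj = nn.Linear(d_model, d_model, bias=False)     # stand-in for the query projection

cat = torch.cat([mem, hidden], dim=0)              # [mlen+qlen, bsz, d_model]
q_from_cat = proj(cat)[-qlen:]                     # project everything, keep last qlen
q_from_hidden = proj(hidden)                       # project the current segment only

print(torch.allclose(q_from_cat, q_from_hidden))   # True: the projection is position-wise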

PyTorch multiGPU training

Thank you for releasing such awesome and easy-to-use code!

Could you please elaborate a bit on the PyTorch implementation's multi-GPU setup? More concretely, what does the parameter "gpu0_bsz" mean, and which parameters should I change to scale this code to setups with more (or fewer) than 4 GPUs?

From the description it seems that "gpu0_bsz" is the batch size for GPU 0, but it is not clear to me why it should differ from the batch sizes on the other GPUs.
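
For intuition, here is a small sketch of how such an unbalanced split could be computed; the bsz_unit formula mirrors the one visible in a data_parallel traceback later on this page, but the function itself is an illustration, not the repository's implementation:

def split_batch(bsz, num_dev, gpu0_bsz):
    # GPU 0 also gathers outputs and computes the loss under nn.DataParallel,
    # so it is typically given a smaller slice of the batch.
    if num_dev == 1:
        return [bsz]                              # nothing to split
    bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)  # even share for the other GPUs
    sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
    sizes[-1] += bsz - sum(sizes)                 # absorb any remainder
    return sizes

print(split_batch(60, 4, 4))   # [4, 18, 18, 20]; scaling to more/fewer GPUs changes num_dev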

Generation script

Can you include a simple script for generating text with a pretrained Transformer-XL language model? We are primarily using the PyTorch codebase, but I am sure TensorFlow users would also appreciate this example.

If including this script is outside the scope of the project repository, could an informal example be provided in this issue thread?
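
In the meantime, a rough, hypothetical sketch of incremental sampling with segment-level memory; model_step is a made-up helper standing in for a forward pass that returns next-token log-probabilities plus updated mems, not the actual model API in this repository:

import torch

def sample(model_step, prime_ids, steps, temperature=1.0):
    # model_step(ids, mems) -> (log_probs, new_mems) is an assumed interface.
    mems = None
    out = list(prime_ids)
    ids = torch.tensor(prime_ids).unsqueeze(-1)        # [seq_len, 1]
    for _ in range(steps):
        log_probs, mems = model_step(ids, mems)        # condition on cached segments
        probs = torch.softmax(log_probs[-1, 0] / temperature, dim=-1)
        next_id = torch.multinomial(probs, 1).item()
        out.append(next_id)
        ids = torch.tensor([[next_id]])                # only feed the newest token
    return out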

Unable to replicate experiment results

I am not sure what is going wrong, but I have been training enwik8 with your TensorFlow code and the default parameters on 4 GPUs for 4 days, and the loss never drops below 4.2 even though the learning rate has already dropped to 0.000001. Are there any special tricks needed to replicate the experiments?

Thanks.

P.S. I am using Python 3 and TensorFlow 1.11.0. I have not tried the other 3 datasets yet. I also tried Transformer-XL on a private dataset (where a single-layer word-level LSTM can achieve around 60%+ accuracy), and there its loss also never drops below 4.2 and accuracy never goes above 15%.

Finetune with transformer-xl pretrained models

Hi, thanks for your excellent work. Transformer-XL is the most elegant model for long sequences so far. Do you plan to finetune the pretrained models for document classification, as was done with BERT?

Different ppl values for same inputs

Hi,
I observed slightly different values when evaluating the perplexity of a set of sentences with batch_size = 1 versus looping through the sentences one by one (all other parameters being the same).
The loss is 0.7707 in one case vs 0.7564 in the other.

I created the data iterator using dataset="lm1b".
Note: I modified corpus.vocab.encode_file to encode the input sentence instead of reading from a file.
Is there any particular reason why this happens?

Google one-billion experiments

With the config python train.py --cuda --data ../data/one-billion-words/ --dataset lm1b --adaptive --n_layer 18 --d_model 1024 --div_val 4 --n_head 8 --d_head 128 --d_inner 4096 --dropout 0.0 --dropatt 0.0 --optim adam --log-interval 5 --eval-interval 20 --warmup_step 20000 --max_step 500000 --lr 0.00025 --tgt_len 32 --mem_len 32 --eval_tgt_len 32 --batch_size 240 --batch_chunk 8 --work_dir exps, run on only one GPU, do you think it is possible to achieve results similar to those in the paper?

License for the released code?

Love the work and really glad to see getdata.sh was useful and that you extended it! Wrangling datasets is never the fun part ;)

Can you clarify the license for the released code by adding a LICENSE file?

Tensor2Tensor compatibility

Thank you for such easy-to-read code and such a well-organized repo; it is clear that a lot of hard work has gone into it! Secondly, I found your work through Sebastian Ruder's NLP newsletter, where he put it as: "Peer review is an imprecise process and gems may sometimes fall through the cracks." Your work was one of the gems, and I totally agree!

Now, specifically: I tried using wt103 in Tensor2Tensor and I'm getting the following error:

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key transformer/body/decoder/layer_0/ffn/conv1/bias not found in checkpoint
	 [[node save/RestoreV2_1 (defined at /home/ubuntu/tensor2tensor/venv/lib/python3.5/site-packages/tensor2tensor/utils/decoding.py:586)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]

I suppose this comes from my using the wrong hparams?

@registry.register_hparams
def transformer_xl():
  """Hparams for transformer-xl"""
  hparams = transformer.transformer_base()
  hparams.batch_size = 2048
  hparams.hidden_size = 4096
  hparams.filter_size = 3072
  hparams.num_hidden_layers = 18
  hparams.num_heads = 16
  hparams.max_length = 1024
  hparams.eval_drop_long_sequences = True
  return hparams

Tensor2Tensor transformer hparams

It seems the eval speed of Transformer-XL is not faster than bert-base-uncased.

I used the code from here: https://github.com/huggingface/pytorch-pretrained-BERT


I ran run_classifier.py with bert-base-uncased and max_seq_length=128 on the MRPC task.
The log:

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
02/21/2019 12:11:44 - INFO - __main__ -   device: cpu n_gpu: 1, distributed training: False, 16-bits training: False
02/21/2019 12:11:45 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/tong.guo/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
02/21/2019 12:11:45 - INFO - pytorch_pretrained_bert.modeling -   loading archive file ../model_file/bert-base-uncased.tar.gz
02/21/2019 12:11:45 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ../model_file/bert-base-uncased.tar.gz to temp dir /tmp/tmpaho9_3dk
02/21/2019 12:11:50 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

02/21/2019 12:11:55 - INFO - pytorch_pretrained_bert.modeling -   Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
02/21/2019 12:11:55 - INFO - pytorch_pretrained_bert.modeling -   Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
02/21/2019 12:11:55 - INFO - pytorch_pretrained_bert.modeling -   loading archive file ../model_file/bert-base-uncased.tar.gz
02/21/2019 12:11:55 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ../model_file/bert-base-uncased.tar.gz to temp dir /tmp/tmpfehb71wu
02/21/2019 12:11:59 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

02/21/2019 12:12:03 - INFO - pytorch_pretrained_bert.modeling -   Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
02/21/2019 12:12:03 - INFO - pytorch_pretrained_bert.modeling -   Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
02/21/2019 12:12:03 - INFO - __main__ -   *** Example ***
02/21/2019 12:12:03 - INFO - __main__ -   guid: dev-1
02/21/2019 12:12:03 - INFO - __main__ -   tokens: [CLS] [UNK] ' s chief operating officer , [UNK] [UNK] , and [UNK] [UNK] , the chief financial officer , will report directly to [UNK] [UNK] . [SEP] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] and [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] will report to [UNK] . [SEP]
02/21/2019 12:12:03 - INFO - __main__ -   input_ids: 101 100 1005 1055 2708 4082 2961 1010 100 100 1010 1998 100 100 1010 1996 2708 3361 2961 1010 2097 3189 3495 2000 100 100 1012 102 100 100 100 100 100 100 1998 100 100 100 100 100 100 2097 3189 2000 100 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   label: 1 (id = 1)
02/21/2019 12:12:03 - INFO - __main__ -   *** Example ***
02/21/2019 12:12:03 - INFO - __main__ -   guid: dev-2
02/21/2019 12:12:03 - INFO - __main__ -   tokens: [CLS] [UNK] world ' s two largest auto ##makers said their [UNK] . [UNK] . sales declined more than predicted last month as a late summer sales frenzy caused more of an industry backlash than expected . [SEP] [UNK] sales at both [UNK] and [UNK] . 2 [UNK] [UNK] [UNK] . declined more than predicted as a late summer sales frenzy prompted a larger - than - expected industry backlash . [SEP]
02/21/2019 12:12:03 - INFO - __main__ -   input_ids: 101 100 2088 1005 1055 2048 2922 8285 12088 2056 2037 100 1012 100 1012 4341 6430 2062 2084 10173 2197 3204 2004 1037 2397 2621 4341 21517 3303 2062 1997 2019 3068 25748 2084 3517 1012 102 100 4341 2012 2119 100 1998 100 1012 1016 100 100 100 1012 6430 2062 2084 10173 2004 1037 2397 2621 4341 21517 9469 1037 3469 1011 2084 1011 3517 3068 25748 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   label: 1 (id = 1)
02/21/2019 12:12:03 - INFO - __main__ -   *** Example ***
02/21/2019 12:12:03 - INFO - __main__ -   guid: dev-3
02/21/2019 12:12:03 - INFO - __main__ -   tokens: [CLS] [UNK] to the federal [UNK] for [UNK] [UNK] and [UNK] ( news - web sites ) , there were 19 reported cases of me ##as ##les in the [UNK] [UNK] in 2002 . [SEP] [UNK] [UNK] for [UNK] [UNK] and [UNK] said there were 19 reported cases of me ##as ##les in the [UNK] [UNK] in 2002 . [SEP]
02/21/2019 12:12:03 - INFO - __main__ -   input_ids: 101 100 2000 1996 2976 100 2005 100 100 1998 100 1006 2739 1011 4773 4573 1007 1010 2045 2020 2539 2988 3572 1997 2033 3022 4244 1999 1996 100 100 1999 2526 1012 102 100 100 2005 100 100 1998 100 2056 2045 2020 2539 2988 3572 1997 2033 3022 4244 1999 1996 100 100 1999 2526 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   label: 1 (id = 1)
02/21/2019 12:12:03 - INFO - __main__ -   *** Example ***
02/21/2019 12:12:03 - INFO - __main__ -   guid: dev-4
02/21/2019 12:12:03 - INFO - __main__ -   tokens: [CLS] [UNK] tropical storm rapidly developed in the [UNK] of [UNK] [UNK] and was expected to hit somewhere along the [UNK] or [UNK] coasts by [UNK] night . [SEP] [UNK] tropical storm rapidly developed in the [UNK] of [UNK] on [UNK] and could have hurricane - force winds when it hits land somewhere along the [UNK] coast [UNK] night . [SEP]
02/21/2019 12:12:03 - INFO - __main__ -   input_ids: 101 100 5133 4040 5901 2764 1999 1996 100 1997 100 100 1998 2001 3517 2000 2718 4873 2247 1996 100 2030 100 20266 2011 100 2305 1012 102 100 5133 4040 5901 2764 1999 1996 100 1997 100 2006 100 1998 2071 2031 7064 1011 2486 7266 2043 2009 4978 2455 4873 2247 1996 100 3023 100 2305 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   label: 0 (id = 0)
02/21/2019 12:12:03 - INFO - __main__ -   *** Example ***
02/21/2019 12:12:03 - INFO - __main__ -   guid: dev-5
02/21/2019 12:12:03 - INFO - __main__ -   tokens: [CLS] [UNK] company didn ' t detail the costs of the replacement and repairs . [SEP] [UNK] company officials expect the costs of the replacement work to run into the millions of dollars . [SEP]
02/21/2019 12:12:03 - INFO - __main__ -   input_ids: 101 100 2194 2134 1005 1056 6987 1996 5366 1997 1996 6110 1998 10315 1012 102 100 2194 4584 5987 1996 5366 1997 1996 6110 2147 2000 2448 2046 1996 8817 1997 6363 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
02/21/2019 12:12:03 - INFO - __main__ -   label: 0 (id = 0)
02/21/2019 12:12:04 - INFO - __main__ -   ***** Running evaluation *****
02/21/2019 12:12:04 - INFO - __main__ -     Num examples = 1725
02/21/2019 12:12:04 - INFO - __main__ -     Batch size = 8
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 216/216 [06:06<00:00,  1.70s/it]
02/21/2019 12:18:11 - INFO - __main__ -   ***** Eval results *****
02/21/2019 12:18:11 - INFO - __main__ -     eval_accuracy = 0.33507246376811595
02/21/2019 12:18:11 - INFO - __main__ -     eval_loss = 1.002936492777533
02/21/2019 12:18:11 - INFO - __main__ -     global_step = 0
02/21/2019 12:18:11 - INFO - __main__ -     loss = Non

The speed is about 1.7 s/batch.


I ran run_transfo_xl.py on the wikitext-103 task.
The log:

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
02/20/2019 19:49:30 - INFO - __main__ -   device: cuda
02/20/2019 19:49:30 - INFO - pytorch_pretrained_bert.tokenization_transfo_xl -   loading vocabulary file ../model_file/transfo-xl-wt103-vocab.bin
02/20/2019 19:49:30 - INFO - pytorch_pretrained_bert.tokenization_transfo_xl -   loading corpus file ../model_file/transfo-xl-wt103-corpus.bin
02/20/2019 19:49:36 - INFO - pytorch_pretrained_bert.modeling_transfo_xl -   loading weights file ../model_file/transfo-xl-wt103-pytorch_model.bin
02/20/2019 19:49:36 - INFO - pytorch_pretrained_bert.modeling_transfo_xl -   loading configuration file ../model_file/transfo-xl-wt103-config.json
02/20/2019 19:49:36 - INFO - pytorch_pretrained_bert.modeling_transfo_xl -   Model config {
  "adaptive": true,
  "attn_type": 0,
  "clamp_len": 1000,
  "cutoffs": [
    20000,
    40000,
    200000
  ],
  "d_embed": 1024,
  "d_head": 64,
  "d_inner": 4096,
  "d_model": 1024,
  "div_val": 4,
  "dropatt": 0.0,
  "dropout": 0.1,
  "ext_len": 0,
  "init": "normal",
  "init_range": 0.01,
  "init_std": 0.02,
  "mem_len": 1600,
  "n_head": 16,
  "n_layer": 18,
  "n_token": 267735,
  "pre_lnorm": false,
  "proj_init_std": 0.01,
  "same_length": true,
  "sample_softmax": -1,
  "tgt_len": 128,
  "tie_projs": [
    false,
    true,
    true,
    true
  ],
  "tie_weight": true,
  "untie_r": true
}

02/20/2019 19:49:51 - INFO - __main__ -   Evaluating with bsz 10 tgt_len 128 ext_len 0 mem_len 1600 clamp_len 1000
02/20/2019 19:57:35 - INFO - __main__ -   Time : 464.00s, 2416.66ms/segment
02/20/2019 19:57:35 - INFO - __main__ -   ====================================================================================================
02/20/2019 19:57:35 - INFO - __main__ -   | test loss  2.90 | test ppl    18.213 
02/20/2019 19:57:35 - INFO - __main__ -   ====================================================================================================

The speed is about 2.4 s/batch.

trouble loading pytorch model

Hello,

Thanks for the PyTorch version of Transformer-XL. I trained a model on my own corpus and it ran smoothly, but I can't seem to load the model back from the checkpoint. I tried loading it the same way as in the eval.py script. Printing the model gives an attribute error:

model = torch.load('model.pt')
print(model)
AttributeError: 'NoneType' object has no attribute 'size'
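
A common, generic workaround (standard PyTorch, not specific to this repository) is to save and load the state_dict instead of pickling the whole module, so loading does not depend on the exact state of the saved object:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                                 # stand-in for the trained network
torch.save(model.state_dict(), 'model_state.pt')        # parameters and buffers only

reloaded = nn.Linear(4, 2)                              # rebuild the same architecture
reloaded.load_state_dict(torch.load('model_state.pt', map_location='cpu'))
reloaded.eval()
print(reloaded)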

Usage of the additional parameter ext_len

Hi,
Thanks for this great piece of work (research and code), it's very impressive!
I am wondering why the PyTorch version has the additional parameter ext_len, which doesn't seem to be used in the TensorFlow version.

problem in tf code

When I read your TF code, I was really confused by the lines below:

rw_head_q = w_head_q + r_w_bias
rr_head_q = w_head_q + r_r_bias

AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
BD = rel_shift(BD)

Could you tell me what the variables r_w_bias and r_r_bias correspond to in your paper?
And could you briefly explain this code?
Thanks for your effort.

Quick question on comparison against BERT

Thanks for the code! I am sure my question will be asked over and over again in the near future. I have also read your paper, which focuses on the comparison against the vanilla Transformer.

But still, in terms of performance, have you compared your model against BERT? I realize it may not be a 100% fair comparison, but at the end of the day... which one (BERT or Transformer-XL) is better on typical NLP tasks? Thanks.

Does the "Beyond Fixed-Length" solution make the new architecture compatible with incomplete-sentence tasks (e.g. text generation)?

In the past couple of months I've been trying to get a vanilla Transformer to be able to do text generation, more specifically to generate a long sentence starting from a small prime sentence.

This has so far failed, and in multiple discussions it was pointed out that a Transformer only works on a fixed-length context, while text generation is an incomplete-sentence task and therefore not fixed-length.

In your educated guess, is there merit in trying to fit this new architecture to the task of text generation now that the fixed-length problem has been solved?

parameters in tf code

Hi,
Does anyone know what the parameters 'bin_sizes' and 'cutoffs' do for the lm1b model?
Thanks for your help

Some problems when training

Thank you for your Transformer code!
When I ran it, I encountered the following issue:

Traceback (most recent call last):
File "train.py", line 539, in
train()
File "train.py", line 451, in train
loss.backward()
File "/data1/baiye/miniconda3/envs/torch04py3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/data1/baiye/miniconda3/envs/torch04py3/lib/python3.6/site-packages/torch/autograd/init.py", line 89, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

My environment is PyTorch 0.4. I checked the code and did not find any in-place operation.

relative attention score

Hello, I have read the paper, but I'm confused by the relative attention score. I don't understand why transpose(u) and transpose(v) are defined separately, as they seem to have the same meaning. Is there anything I should take into account?

sparse updates with multi-GPU

Hello, very nice work, and thank you for sharing the source.

I was looking at the PyTorch implementation and wondering how you make multi-GPU training work with sparse updates, especially when fp16 is activated, given that neither sparse/fp16 nor sparse/distributed is currently implemented in PyTorch. My feeling is that in the current code you have an optimizer that synchronizes the parameters across GPUs as expected, but the sparse updates are never synchronized, which should result in slightly different models in each process. Or maybe I am missing something?

Thank you

Some questions about pytorch code and details.

Hi there,
It's great that you released the original code, though it may be a little difficult for me to reproduce :(
After nearly 1.5 days of matching your paper and code, I still have some questions about the model structure. I hope you can help; maybe some of them are foolish...

  1. What is the difference between RelLearnableMultiHeadAttn and RelPartialLearnableMultiHeadAttn?
    The most important part seems to be the construction of the embedding (A+B+C+D), but the first one doesn't use the positional embedding from "Attention Is All You Need"?

  2. Can you explain the function _rel_shift in detail?
    Especially the first few lines; I don't know why we need this. (A small illustration of the shift follows after this list.)

  3. What happens when the param div_val > 1, and what is the meaning of cutoff_xxx?
    More specifically, I think all we need is the code path where div_val == 1.

Hope you can help, thanks.
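
A small, self-contained illustration of the standard relative-shift trick (written for this page, not copied from the repository): the score matrix is padded with one zero column, the same buffer is re-read with the first two sizes swapped, and one row is dropped, which realigns each query row with its own relative distances.

import torch

def rel_shift(x):
    # x: [qlen, klen]; column j was computed against relative position klen-1-j.
    qlen, klen = x.size()
    zero_pad = torch.zeros(qlen, 1, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)   # [qlen, klen+1]
    x_padded = x_padded.view(klen + 1, qlen)     # re-read the same buffer row-major
    return x_padded[1:].view(qlen, klen)         # drop one row, reshape back

qlen, mlen = 3, 2
klen = qlen + mlen
# Label every column with the relative distance it encodes, so the shift is visible.
x = torch.arange(klen - 1, -1, -1).repeat(qlen, 1).float()
print(rel_shift(x))
# Row i now reads i+mlen, ..., 1, 0 in its causally valid positions; entries beyond
# the causal mask hold wrapped-around values that the mask later removes.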

How to train models with attn_type=2 on wiki103 training set?

I want to train a model with attn_type=2; here is my configuration.
Experiment dir : wt103_workdir/-wt103/20190121-201645
Loading cached dataset...

- data : ../data/wikitext-103/
- dataset : wt103
- n_layer : 16
- n_head : 10
- d_head : 41
- d_embed : 410
- d_model : 410
- d_inner : 2100
- dropout : 0.1
- dropatt : 0.0
- init : normal
- emb_init : normal
- init_range : 0.1
- emb_init_range : 0.01
- init_std : 0.02
- proj_init_std : 0.01
- optim : adam
- lr : 0.00025
- mom : 0.0
- scheduler : cosine
- warmup_step : 0
- decay_rate : 0.5
- lr_min : 0.0
- clip : 0.25
- clip_nonemb : False
- max_step : 200000
- batch_size : 60
- batch_chunk : 1
- tgt_len : 150
- eval_tgt_len : 150
- ext_len : 0
- mem_len : 0
- not_tied : False
- seed : 1111
- cuda : True
- adaptive : True
- div_val : 1
- pre_lnorm : False
- varlen : False
- multi_gpu : True
- log_interval : 200
- eval_interval : 4000
- work_dir : wt103_workdir/-wt103/20190121-201645
- restart : False
- restart_dir :
- debug : False
- same_length : False
- attn_type : 2
- clamp_len : -1
- eta_min : 0.0
- gpu0_bsz : 4
- max_eval_steps : -1
- sample_softmax : -1
- patience : 0
- finetune_v2 : False
- finetune_v3 : False
- fp16 : False
- static_loss_scale : 1
- dynamic_loss_scale : False
- tied : True
- n_token : 267735
- n_all_param : 148417118
- n_nonemb_param : 38376800

But training does not seem to converge. Can anyone give me some advice? Thanks very much!
| epoch 1 step 200 | 200 batches | lr 0.00025 | ms/batch 702.80 | loss 7.64 | ppl 2088.928
| epoch 1 step 400 | 400 batches | lr 0.00025 | ms/batch 621.46 | loss 7.46 | ppl 1730.912
| epoch 1 step 600 | 600 batches | lr 0.00025 | ms/batch 621.39 | loss 7.45 | ppl 1728.002
| epoch 1 step 800 | 800 batches | lr 0.00025 | ms/batch 621.19 | loss 7.45 | ppl 1717.449
| epoch 1 step 1000 | 1000 batches | lr 0.00025 | ms/batch 621.40 | loss 7.45 | ppl 1720.351
| epoch 1 step 1200 | 1200 batches | lr 0.00025 | ms/batch 620.97 | loss 7.44 | ppl 1700.753
| epoch 1 step 1400 | 1400 batches | lr 0.00025 | ms/batch 620.69 | loss 7.44 | ppl 1695.095
| epoch 1 step 1600 | 1600 batches | lr 0.00025 | ms/batch 635.92 | loss 7.44 | ppl 1711.255
| epoch 1 step 1800 | 1800 batches | lr 0.00025 | ms/batch 620.47 | loss 7.44 | ppl 1710.397
| epoch 1 step 2000 | 2000 batches | lr 0.00025 | ms/batch 619.97 | loss 7.44 | ppl 1695.020
| epoch 1 step 2200 | 2200 batches | lr 0.00025 | ms/batch 620.45 | loss 7.43 | ppl 1690.592
| epoch 1 step 2400 | 2400 batches | lr 0.00025 | ms/batch 620.02 | loss 7.44 | ppl 1702.485
| epoch 1 step 2600 | 2600 batches | lr 0.00025 | ms/batch 620.64 | loss 7.43 | ppl 1689.785
| epoch 1 step 2800 | 2800 batches | lr 0.00025 | ms/batch 619.92 | loss 7.43 | ppl 1693.790
| epoch 1 step 3000 | 3000 batches | lr 0.00025 | ms/batch 620.23 | loss 7.43 | ppl 1684.638
| epoch 1 step 3200 | 3200 batches | lr 0.00025 | ms/batch 619.65 | loss 7.42 | ppl 1666.079
| epoch 1 step 3400 | 3400 batches | lr 0.00025 | ms/batch 619.42 | loss 7.41 | ppl 1659.356
| epoch 1 step 3600 | 3600 batches | lr 0.00025 | ms/batch 619.80 | loss 7.41 | ppl 1649.053
| epoch 1 step 3800 | 3800 batches | lr 0.00025 | ms/batch 620.29 | loss 7.45 | ppl 1711.666
| epoch 1 step 4000 | 4000 batches | lr 0.00025 | ms/batch 620.61 | loss 7.42 | ppl 1661.076

| Eval 1 at step 4000 | time: 2506.77s | valid loss 7.41 | valid ppl 1654.046

| epoch 1 step 4200 | 4200 batches | lr 0.00025 | ms/batch 662.27 | loss 7.43 | ppl 1682.745
| epoch 1 step 4400 | 4400 batches | lr 0.00025 | ms/batch 619.80 | loss 7.42 | ppl 1672.333
| epoch 1 step 4600 | 4600 batches | lr 0.00025 | ms/batch 619.94 | loss 7.42 | ppl 1668.553
| epoch 1 step 4800 | 4800 batches | lr 0.00025 | ms/batch 619.89 | loss 7.42 | ppl 1669.587
| epoch 1 step 5000 | 5000 batches | lr 0.00025 | ms/batch 620.04 | loss 7.44 | ppl 1705.483
| epoch 1 step 5200 | 5200 batches | lr 0.00025 | ms/batch 619.59 | loss 7.44 | ppl 1707.168
| epoch 1 step 5400 | 5400 batches | lr 0.00025 | ms/batch 619.41 | loss 7.41 | ppl 1656.997
| epoch 1 step 5600 | 5600 batches | lr 0.00025 | ms/batch 619.81 | loss 7.43 | ppl 1682.111
| epoch 1 step 5800 | 5800 batches | lr 0.000249 | ms/batch 620.20 | loss 7.44 | ppl 1695.797
| epoch 1 step 6000 | 6000 batches | lr 0.000249 | ms/batch 619.68 | loss 7.43 | ppl 1691.197
| epoch 1 step 6200 | 6200 batches | lr 0.000249 | ms/batch 619.43 | loss 7.41 | ppl 1654.504
| epoch 1 step 6400 | 6400 batches | lr 0.000249 | ms/batch 620.16 | loss 7.43 | ppl 1688.890
| epoch 1 step 6600 | 6600 batches | lr 0.000249 | ms/batch 620.38 | loss 7.43 | ppl 1678.386
| epoch 1 step 6800 | 6800 batches | lr 0.000249 | ms/batch 619.64 | loss 7.42 | ppl 1664.799
| epoch 1 step 7000 | 7000 batches | lr 0.000249 | ms/batch 619.72 | loss 7.42 | ppl 1670.465
| epoch 1 step 7200 | 7200 batches | lr 0.000249 | ms/batch 620.02 | loss 7.42 | ppl 1667.845
| epoch 1 step 7400 | 7400 batches | lr 0.000249 | ms/batch 619.56 | loss 7.42 | ppl 1670.860
| epoch 1 step 7600 | 7600 batches | lr 0.000249 | ms/batch 620.25 | loss 7.42 | ppl 1664.089
| epoch 1 step 7800 | 7800 batches | lr 0.000249 | ms/batch 619.83 | loss 7.42 | ppl 1661.560
| epoch 1 step 8000 | 8000 batches | lr 0.000249 | ms/batch 619.85 | loss 7.43 | ppl 1689.380

| Eval 2 at step 8000 | time: 2484.77s | valid loss 7.39 | valid ppl 1616.071

| epoch 1 step 8200 | 8200 batches | lr 0.000249 | ms/batch 672.76 | loss 7.43 | ppl 1680.628
| epoch 1 step 8400 | 8400 batches | lr 0.000249 | ms/batch 620.21 | loss 7.43 | ppl 1685.528
| epoch 1 step 8600 | 8600 batches | lr 0.000249 | ms/batch 619.91 | loss 7.43 | ppl 1684.851
| epoch 1 step 8800 | 8800 batches | lr 0.000249 | ms/batch 620.02 | loss 7.44 | ppl 1699.004
| epoch 1 step 9000 | 9000 batches | lr 0.000249 | ms/batch 619.53 | loss 7.42 | ppl 1667.265
| epoch 1 step 9200 | 9200 batches | lr 0.000249 | ms/batch 620.79 | loss 7.43 | ppl 1684.868
| epoch 1 step 9400 | 9400 batches | lr 0.000249 | ms/batch 620.06 | loss 7.42 | ppl 1672.693
| epoch 1 step 9600 | 9600 batches | lr 0.000249 | ms/batch 619.62 | loss 7.43 | ppl 1689.861
| epoch 1 step 9800 | 9800 batches | lr 0.000249 | ms/batch 619.44 | loss 7.41 | ppl 1652.922
| epoch 1 step 10000 | 10000 batches | lr 0.000248 | ms/batch 620.06 | loss 7.43 | ppl 1692.675
| epoch 1 step 10200 | 10200 batches | lr 0.000248 | ms/batch 619.46 | loss 7.41 | ppl 1653.468
| epoch 1 step 10400 | 10400 batches | lr 0.000248 | ms/batch 619.62 | loss 7.41 | ppl 1651.442
| epoch 1 step 10600 | 10600 batches | lr 0.000248 | ms/batch 620.05 | loss 7.41 | ppl 1652.406
| epoch 1 step 10800 | 10800 batches | lr 0.000248 | ms/batch 619.74 | loss 7.41 | ppl 1658.664
| epoch 1 step 11000 | 11000 batches | lr 0.000248 | ms/batch 619.59 | loss 7.44 | ppl 1694.259
| epoch 1 step 11200 | 11200 batches | lr 0.000248 | ms/batch 619.55 | loss 7.42 | ppl 1672.915
| epoch 1 step 11400 | 11400 batches | lr 0.000248 | ms/batch 619.08 | loss 7.42 | ppl 1664.737
| epoch 2 step 11600 | 130 batches | lr 0.000248 | ms/batch 620.83 | loss 7.38 | ppl 1601.478
| epoch 2 step 11800 | 330 batches | lr 0.000248 | ms/batch 621.36 | loss 7.31 | ppl 1490.017
| epoch 2 step 12000 | 530 batches | lr 0.000248 | ms/batch 621.48 | loss 7.33 | ppl 1523.110

| Eval 3 at step 12000 | time: 2485.37s | valid loss 7.42 | valid ppl 1674.380

| epoch 2 step 12200 | 730 batches | lr 0.000248 | ms/batch 648.06 | loss 7.31 | ppl 1498.382
| epoch 2 step 12400 | 930 batches | lr 0.000248 | ms/batch 621.16 | loss 7.33 | ppl 1527.507
| epoch 2 step 12600 | 1130 batches | lr 0.000248 | ms/batch 621.00 | loss 7.33 | ppl 1530.506
| epoch 2 step 12800 | 1330 batches | lr 0.000247 | ms/batch 620.95 | loss 7.33 | ppl 1525.309
| epoch 2 step 13000 | 1530 batches | lr 0.000247 | ms/batch 621.32 | loss 7.33 | ppl 1527.929
| epoch 2 step 13200 | 1730 batches | lr 0.000247 | ms/batch 621.35 | loss 7.34 | ppl 1543.376
| epoch 2 step 13400 | 1930 batches | lr 0.000247 | ms/batch 621.04 | loss 7.33 | ppl 1523.908
| epoch 2 step 13600 | 2130 batches | lr 0.000247 | ms/batch 621.29 | loss 7.34 | ppl 1546.512
| epoch 2 step 13800 | 2330 batches | lr 0.000247 | ms/batch 620.99 | loss 7.34 | ppl 1545.263
| epoch 2 step 14000 | 2530 batches | lr 0.000247 | ms/batch 621.08 | loss 7.34 | ppl 1540.291
| epoch 2 step 14200 | 2730 batches | lr 0.000247 | ms/batch 620.93 | loss 7.34 | ppl 1540.285
| epoch 2 step 14400 | 2930 batches | lr 0.000247 | ms/batch 621.68 | loss 7.34 | ppl 1540.759
| epoch 2 step 14600 | 3130 batches | lr 0.000247 | ms/batch 621.22 | loss 7.32 | ppl 1512.795
| epoch 2 step 14800 | 3330 batches | lr 0.000247 | ms/batch 621.04 | loss 7.32 | ppl 1506.678
| epoch 2 step 15000 | 3530 batches | lr 0.000247 | ms/batch 621.31 | loss 7.33 | ppl 1530.028
| epoch 2 step 15200 | 3730 batches | lr 0.000246 | ms/batch 621.44 | loss 7.34 | ppl 1537.768
| epoch 2 step 15400 | 3930 batches | lr 0.000246 | ms/batch 621.56 | loss 7.33 | ppl 1532.047
| epoch 2 step 15600 | 4130 batches | lr 0.000246 | ms/batch 622.21 | loss 7.34 | ppl 1535.568
| epoch 2 step 15800 | 4330 batches | lr 0.000246 | ms/batch 621.75 | loss 7.34 | ppl 1537.776
| epoch 2 step 16000 | 4530 batches | lr 0.000246 | ms/batch 621.52 | loss 7.33 | ppl 1524.707

| Eval 4 at step 16000 | time: 2490.61s | valid loss 7.42 | valid ppl 1664.061

| epoch 2 step 16200 | 4730 batches | lr 0.000246 | ms/batch 648.01 | loss 7.33 | ppl 1531.670
| epoch 2 step 16400 | 4930 batches | lr 0.000246 | ms/batch 621.81 | loss 7.35 | ppl 1561.311
| epoch 2 step 16600 | 5130 batches | lr 0.000246 | ms/batch 621.62 | loss 7.35 | ppl 1558.448
| epoch 2 step 16800 | 5330 batches | lr 0.000246 | ms/batch 621.31 | loss 7.34 | ppl 1544.213
| epoch 2 step 17000 | 5530 batches | lr 0.000246 | ms/batch 621.27 | loss 7.33 | ppl 1521.180
| epoch 2 step 17200 | 5730 batches | lr 0.000245 | ms/batch 621.11 | loss 7.36 | ppl 1577.214
| epoch 2 step 17400 | 5930 batches | lr 0.000245 | ms/batch 620.95 | loss 7.35 | ppl 1551.961
| epoch 2 step 17600 | 6130 batches | lr 0.000245 | ms/batch 620.91 | loss 7.34 | ppl 1546.448
| epoch 2 step 17800 | 6330 batches | lr 0.000245 | ms/batch 621.03 | loss 7.34 | ppl 1534.776
| epoch 2 step 18000 | 6530 batches | lr 0.000245 | ms/batch 621.68 | loss 7.36 | ppl 1571.506
| epoch 2 step 18200 | 6730 batches | lr 0.000245 | ms/batch 621.27 | loss 7.34 | ppl 1535.712
| epoch 2 step 18400 | 6930 batches | lr 0.000245 | ms/batch 621.65 | loss 7.34 | ppl 1538.308
| epoch 2 step 18600 | 7130 batches | lr 0.000245 | ms/batch 620.88 | loss 7.34 | ppl 1541.480
| epoch 2 step 18800 | 7330 batches | lr 0.000245 | ms/batch 621.06 | loss 7.34 | ppl 1539.062
| epoch 2 step 19000 | 7530 batches | lr 0.000244 | ms/batch 621.02 | loss 7.35 | ppl 1556.423
| epoch 2 step 19200 | 7730 batches | lr 0.000244 | ms/batch 621.01 | loss 7.33 | ppl 1530.237
| epoch 2 step 19400 | 7930 batches | lr 0.000244 | ms/batch 621.38 | loss 7.35 | ppl 1560.169
| epoch 2 step 19600 | 8130 batches | lr 0.000244 | ms/batch 621.06 | loss 7.34 | ppl 1543.635
| epoch 2 step 19800 | 8330 batches | lr 0.000244 | ms/batch 621.31 | loss 7.34 | ppl 1546.412
| epoch 2 step 20000 | 8530 batches | lr 0.000244 | ms/batch 620.97 | loss 7.36 | ppl 1573.621

| Eval 5 at step 20000 | time: 2490.28s | valid loss 7.40 | valid ppl 1642.454

| epoch 2 step 20200 | 8730 batches | lr 0.000244 | ms/batch 648.18 | loss 7.35 | ppl 1552.783
| epoch 2 step 20400 | 8930 batches | lr 0.000244 | ms/batch 621.19 | loss 7.35 | ppl 1561.782
| epoch 2 step 20600 | 9130 batches | lr 0.000244 | ms/batch 621.56 | loss 7.35 | ppl 1551.505
| epoch 2 step 20800 | 9330 batches | lr 0.000243 | ms/batch 621.11 | loss 7.35 | ppl 1550.757
| epoch 2 step 21000 | 9530 batches | lr 0.000243 | ms/batch 621.23 | loss 7.36 | ppl 1576.482
| epoch 2 step 21200 | 9730 batches | lr 0.000243 | ms/batch 621.01 | loss 7.34 | ppl 1534.469
| epoch 2 step 21400 | 9930 batches | lr 0.000243 | ms/batch 621.09 | loss 7.35 | ppl 1552.626
| epoch 2 step 21600 | 10130 batches | lr 0.000243 | ms/batch 621.28 | loss 7.35 | ppl 1550.348
| epoch 2 step 21800 | 10330 batches | lr 0.000243 | ms/batch 621.55 | loss 7.35 | ppl 1555.845
| epoch 2 step 22000 | 10530 batches | lr 0.000243 | ms/batch 620.96 | loss 7.34 | ppl 1533.085
| epoch 2 step 22200 | 10730 batches | lr 0.000242 | ms/batch 620.96 | loss 7.35 | ppl 1556.160
| epoch 2 step 22400 | 10930 batches | lr 0.000242 | ms/batch 621.35 | loss 7.35 | ppl 1562.793
| epoch 2 step 22600 | 11130 batches | lr 0.000242 | ms/batch 620.82 | loss 7.35 | ppl 1563.720
| epoch 2 step 22800 | 11330 batches | lr 0.000242 | ms/batch 621.20 | loss 7.36 | ppl 1566.230
| epoch 3 step 23000 | 60 batches | lr 0.000242 | ms/batch 620.74 | loss 7.34 | ppl 1541.840
| epoch 3 step 23200 | 260 batches | lr 0.000242 | ms/batch 621.55 | loss 7.28 | ppl 1453.898
| epoch 3 step 23400 | 460 batches | lr 0.000242 | ms/batch 622.00 | loss 7.30 | ppl 1479.356
| epoch 3 step 23600 | 660 batches | lr 0.000242 | ms/batch 621.62 | loss 7.29 | ppl 1458.763
| epoch 3 step 23800 | 860 batches | lr 0.000241 | ms/batch 621.74 | loss 7.31 | ppl 1488.324
| epoch 3 step 24000 | 1060 batches | lr 0.000241 | ms/batch 622.03 | loss 7.30 | ppl 1476.624

| Eval 6 at step 24000 | time: 2490.67s | valid loss 7.44 | valid ppl 1703.611

| epoch 3 step 24200 | 1260 batches | lr 0.000241 | ms/batch 648.34 | loss 7.30 | ppl 1478.781
| epoch 3 step 24400 | 1460 batches | lr 0.000241 | ms/batch 621.87 | loss 7.30 | ppl 1476.575
| epoch 3 step 24600 | 1660 batches | lr 0.000241 | ms/batch 621.75 | loss 7.31 | ppl 1499.300
| epoch 3 step 24800 | 1860 batches | lr 0.000241 | ms/batch 621.83 | loss 7.30 | ppl 1477.016
| epoch 3 step 25000 | 2060 batches | lr 0.00024 | ms/batch 622.08 | loss 7.31 | ppl 1500.889
| epoch 3 step 25200 | 2260 batches | lr 0.00024 | ms/batch 621.62 | loss 7.31 | ppl 1495.962
| epoch 3 step 25400 | 2460 batches | lr 0.00024 | ms/batch 621.89 | loss 7.31 | ppl 1492.161
| epoch 3 step 25600 | 2660 batches | lr 0.00024 | ms/batch 621.87 | loss 7.31 | ppl 1492.371
| epoch 3 step 25800 | 2860 batches | lr 0.00024 | ms/batch 621.40 | loss 7.31 | ppl 1491.645
| epoch 3 step 26000 | 3060 batches | lr 0.00024 | ms/batch 621.84 | loss 7.30 | ppl 1484.346
| epoch 3 step 26200 | 3260 batches | lr 0.00024 | ms/batch 621.87 | loss 7.29 | ppl 1466.224
| epoch 3 step 26400 | 3460 batches | lr 0.000239 | ms/batch 621.72 | loss 7.29 | ppl 1471.563
| epoch 3 step 26600 | 3660 batches | lr 0.000239 | ms/batch 621.46 | loss 7.30 | ppl 1484.499
| epoch 3 step 26800 | 3860 batches | lr 0.000239 | ms/batch 621.93 | loss 7.31 | ppl 1492.562
| epoch 3 step 27000 | 4060 batches | lr 0.000239 | ms/batch 621.78 | loss 7.30 | ppl 1474.048
| epoch 3 step 27200 | 4260 batches | lr 0.000239 | ms/batch 621.84 | loss 7.31 | ppl 1498.078
| epoch 3 step 27400 | 4460 batches | lr 0.000239 | ms/batch 621.81 | loss 7.30 | ppl 1477.633
| epoch 3 step 27600 | 4660 batches | lr 0.000238 | ms/batch 621.69 | loss 7.31 | ppl 1491.983
| epoch 3 step 27800 | 4860 batches | lr 0.000238 | ms/batch 621.52 | loss 7.32 | ppl 1507.707
| epoch 3 step 28000 | 5060 batches | lr 0.000238 | ms/batch 621.40 | loss 7.32 | ppl 1507.557

| Eval 7 at step 28000 | time: 2492.27s | valid loss 7.44 | valid ppl 1706.144

| epoch 3 step 28200 | 5260 batches | lr 0.000238 | ms/batch 648.62 | loss 7.32 | ppl 1502.753
| epoch 3 step 28400 | 5460 batches | lr 0.000238 | ms/batch 621.67 | loss 7.31 | ppl 1487.829
| epoch 3 step 28600 | 5660 batches | lr 0.000238 | ms/batch 621.82 | loss 7.33 | ppl 1518.292
| epoch 3 step 28800 | 5860 batches | lr 0.000237 | ms/batch 621.37 | loss 7.32 | ppl 1510.123
| epoch 3 step 29000 | 6060 batches | lr 0.000237 | ms/batch 621.82 | loss 7.31 | ppl 1501.527
| epoch 3 step 29200 | 6260 batches | lr 0.000237 | ms/batch 621.80 | loss 7.31 | ppl 1488.194
| epoch 3 step 29400 | 6460 batches | lr 0.000237 | ms/batch 621.68 | loss 7.32 | ppl 1513.990
| epoch 3 step 29600 | 6660 batches | lr 0.000237 | ms/batch 621.60 | loss 7.32 | ppl 1503.472
| epoch 3 step 29800 | 6860 batches | lr 0.000237 | ms/batch 621.45 | loss 7.31 | ppl 1491.761
| epoch 3 step 30000 | 7060 batches | lr 0.000236 | ms/batch 621.66 | loss 7.31 | ppl 1500.846
| epoch 3 step 30200 | 7260 batches | lr 0.000236 | ms/batch 621.62 | loss 7.31 | ppl 1494.195
| epoch 3 step 30400 | 7460 batches | lr 0.000236 | ms/batch 621.88 | loss 7.31 | ppl 1501.704
| epoch 3 step 30600 | 7660 batches | lr 0.000236 | ms/batch 621.50 | loss 7.31 | ppl 1493.181
| epoch 3 step 30800 | 7860 batches | lr 0.000236 | ms/batch 622.06 | loss 7.31 | ppl 1493.227
| epoch 3 step 31000 | 8060 batches | lr 0.000235 | ms/batch 621.64 | loss 7.31 | ppl 1501.180
| epoch 3 step 31200 | 8260 batches | lr 0.000235 | ms/batch 621.87 | loss 7.31 | ppl 1501.493
| epoch 3 step 31400 | 8460 batches | lr 0.000235 | ms/batch 621.95 | loss 7.33 | ppl 1518.169
| epoch 3 step 31600 | 8660 batches | lr 0.000235 | ms/batch 621.71 | loss 7.31 | ppl 1502.388
| epoch 3 step 31800 | 8860 batches | lr 0.000235 | ms/batch 621.56 | loss 7.32 | ppl 1511.796
| epoch 3 step 32000 | 9060 batches | lr 0.000235 | ms/batch 621.97 | loss 7.31 | ppl 1500.025

| Eval 8 at step 32000 | time: 2492.28s | valid loss 7.44 | valid ppl 1708.262

| epoch 3 step 32200 | 9260 batches | lr 0.000234 | ms/batch 648.48 | loss 7.32 | ppl 1502.871
| epoch 3 step 32400 | 9460 batches | lr 0.000234 | ms/batch 621.60 | loss 7.33 | ppl 1527.820
| epoch 3 step 32600 | 9660 batches | lr 0.000234 | ms/batch 621.64 | loss 7.31 | ppl 1492.826
| epoch 3 step 32800 | 9860 batches | lr 0.000234 | ms/batch 621.63 | loss 7.31 | ppl 1495.967
| epoch 3 step 33000 | 10060 batches | lr 0.000234 | ms/batch 622.10 | loss 7.33 | ppl 1524.507
| epoch 3 step 33200 | 10260 batches | lr 0.000233 | ms/batch 621.53 | loss 7.30 | ppl 1485.690
| epoch 3 step 33400 | 10460 batches | lr 0.000233 | ms/batch 621.17 | loss 7.31 | ppl 1492.452
| epoch 3 step 33600 | 10660 batches | lr 0.000233 | ms/batch 621.04 | loss 7.31 | ppl 1498.407
| epoch 3 step 33800 | 10860 batches | lr 0.000233 | ms/batch 621.77 | loss 7.32 | ppl 1512.605
| epoch 3 step 34000 | 11060 batches | lr 0.000233 | ms/batch 621.23 | loss 7.32 | ppl 1504.775
| epoch 3 step 34200 | 11260 batches | lr 0.000232 | ms/batch 621.52 | loss 7.33 | ppl 1523.020
| epoch 3 step 34400 | 11460 batches | lr 0.000232 | ms/batch 620.80 | loss 7.31 | ppl 1491.825
| epoch 4 step 34600 | 190 batches | lr 0.000232 | ms/batch 621.64 | loss 7.29 | ppl 1465.376
| epoch 4 step 34800 | 390 batches | lr 0.000232 | ms/batch 621.66 | loss 7.28 | ppl 1450.339
| epoch 4 step 35000 | 590 batches | lr 0.000232 | ms/batch 621.68 | loss 7.28 | ppl 1457.213
| epoch 4 step 35200 | 790 batches | lr 0.000231 | ms/batch 621.64 | loss 7.28 | ppl 1456.289
| epoch 4 step 35400 | 990 batches | lr 0.000231 | ms/batch 621.70 | loss 7.29 | ppl 1464.555
| epoch 4 step 35600 | 1190 batches | lr 0.000231 | ms/batch 621.51 | loss 7.28 | ppl 1455.969
| epoch 4 step 35800 | 1390 batches | lr 0.000231 | ms/batch 621.76 | loss 7.29 | ppl 1460.029
| epoch 4 step 36000 | 1590 batches | lr 0.000231 | ms/batch 622.02 | loss 7.29 | ppl 1469.949

| Eval 9 at step 36000 | time: 2491.61s | valid loss 7.45 | valid ppl 1727.050

| epoch 4 step 36200 | 1790 batches | lr 0.00023 | ms/batch 648.55 | loss 7.29 | ppl 1465.296
| epoch 4 step 36400 | 1990 batches | lr 0.00023 | ms/batch 621.63 | loss 7.29 | ppl 1470.353
| epoch 4 step 36600 | 2190 batches | lr 0.00023 | ms/batch 621.60 | loss 7.29 | ppl 1466.556
| epoch 4 step 36800 | 2390 batches | lr 0.00023 | ms/batch 621.81 | loss 7.30 | ppl 1481.549
| epoch 4 step 37000 | 2590 batches | lr 0.000229 | ms/batch 621.74 | loss 7.29 | ppl 1459.671
| epoch 4 step 37200 | 2790 batches | lr 0.000229 | ms/batch 622.05 | loss 7.30 | ppl 1475.031
| epoch 4 step 37400 | 2990 batches | lr 0.000229 | ms/batch 621.83 | loss 7.29 | ppl 1471.230
| epoch 4 step 37600 | 3190 batches | lr 0.000229 | ms/batch 621.65 | loss 7.28 | ppl 1444.171
| epoch 4 step 37800 | 3390 batches | lr 0.000229 | ms/batch 621.73 | loss 7.27 | ppl 1439.950
| epoch 4 step 38000 | 3590 batches | lr 0.000228 | ms/batch 621.45 | loss 7.28 | ppl 1454.786
| epoch 4 step 38200 | 3790 batches | lr 0.000228 | ms/batch 622.01 | loss 7.30 | ppl 1479.319
| epoch 4 step 38400 | 3990 batches | lr 0.000228 | ms/batch 621.87 | loss 7.28 | ppl 1451.045
| epoch 4 step 38600 | 4190 batches | lr 0.000228 | ms/batch 622.00 | loss 7.30 | ppl 1475.214
| epoch 4 step 38800 | 4390 batches | lr 0.000227 | ms/batch 621.46 | loss 7.29 | ppl 1460.207
| epoch 4 step 39000 | 4590 batches | lr 0.000227 | ms/batch 621.24 | loss 7.29 | ppl 1462.538
| epoch 4 step 39200 | 4790 batches | lr 0.000227 | ms/batch 621.45 | loss 7.29 | ppl 1468.254
| epoch 4 step 39400 | 4990 batches | lr 0.000227 | ms/batch 621.85 | loss 7.31 | ppl 1493.102
| epoch 4 step 39600 | 5190 batches | lr 0.000227 | ms/batch 622.29 | loss 7.31 | ppl 1493.513
| epoch 4 step 39800 | 5390 batches | lr 0.000226 | ms/batch 621.24 | loss 7.29 | ppl 1459.525
| epoch 4 step 40000 | 5590 batches | lr 0.000226 | ms/batch 621.47 | loss 7.30 | ppl 1474.168

| Eval 10 at step 40000 | time: 2492.18s | valid loss 7.46 | valid ppl 1730.644

| epoch 4 step 40200 | 5790 batches | lr 0.000226 | ms/batch 648.23 | loss 7.31 | ppl 1495.160
| epoch 4 step 40400 | 5990 batches | lr 0.000226 | ms/batch 621.79 | loss 7.30 | ppl 1486.213
| epoch 4 step 40600 | 6190 batches | lr 0.000225 | ms/batch 621.62 | loss 7.29 | ppl 1462.185
| epoch 4 step 40800 | 6390 batches | lr 0.000225 | ms/batch 621.89 | loss 7.30 | ppl 1473.741
| epoch 4 step 41000 | 6590 batches | lr 0.000225 | ms/batch 621.52 | loss 7.31 | ppl 1488.068
| epoch 4 step 41200 | 6790 batches | lr 0.000225 | ms/batch 621.60 | loss 7.29 | ppl 1458.729
| epoch 4 step 41400 | 6990 batches | lr 0.000224 | ms/batch 621.76 | loss 7.30 | ppl 1480.750
| epoch 4 step 41600 | 7190 batches | lr 0.000224 | ms/batch 621.84 | loss 7.29 | ppl 1471.271
| epoch 4 step 41800 | 7390 batches | lr 0.000224 | ms/batch 621.72 | loss 7.30 | ppl 1479.743
| epoch 4 step 42000 | 7590 batches | lr 0.000224 | ms/batch 621.64 | loss 7.30 | ppl 1473.643
| epoch 4 step 42200 | 7790 batches | lr 0.000224 | ms/batch 621.37 | loss 7.28 | ppl 1447.959
| epoch 4 step 42400 | 7990 batches | lr 0.000223 | ms/batch 621.54 | loss 7.31 | ppl 1495.434
| epoch 4 step 42600 | 8190 batches | lr 0.000223 | ms/batch 621.55 | loss 7.29 | ppl 1462.527
| epoch 4 step 42800 | 8390 batches | lr 0.000223 | ms/batch 621.60 | loss 7.31 | ppl 1489.709
| epoch 4 step 43000 | 8590 batches | lr 0.000223 | ms/batch 621.65 | loss 7.30 | ppl 1484.012
| epoch 4 step 43200 | 8790 batches | lr 0.000222 | ms/batch 621.62 | loss 7.31 | ppl 1491.229
| epoch 4 step 43400 | 8990 batches | lr 0.000222 | ms/batch 621.14 | loss 7.30 | ppl 1476.093
| epoch 4 step 43600 | 9190 batches | lr 0.000222 | ms/batch 621.67 | loss 7.31 | ppl 1487.912
| epoch 4 step 43800 | 9390 batches | lr 0.000222 | ms/batch 621.81 | loss 7.30 | ppl 1474.515
| epoch 4 step 44000 | 9590 batches | lr 0.000221 | ms/batch 621.78 | loss 7.31 | ppl 1489.621

| Eval 11 at step 44000 | time: 2491.86s | valid loss 7.44 | valid ppl 1706.983

| epoch 4 step 44200 | 9790 batches | lr 0.000221 | ms/batch 648.36 | loss 7.29 | ppl 1464.462
| epoch 4 step 44400 | 9990 batches | lr 0.000221 | ms/batch 621.91 | loss 7.31 | ppl 1492.388
| epoch 4 step 44600 | 10190 batches | lr 0.000221 | ms/batch 621.52 | loss 7.29 | ppl 1465.266
| epoch 4 step 44800 | 10390 batches | lr 0.00022 | ms/batch 621.12 | loss 7.30 | ppl 1477.299
| epoch 4 step 45000 | 10590 batches | lr 0.00022 | ms/batch 620.76 | loss 7.29 | ppl 1464.611
| epoch 4 step 45200 | 10790 batches | lr 0.00022 | ms/batch 621.72 | loss 7.30 | ppl 1475.448
| epoch 4 step 45400 | 10990 batches | lr 0.00022 | ms/batch 621.22 | loss 7.31 | ppl 1493.140
| epoch 4 step 45600 | 11190 batches | lr 0.000219 | ms/batch 621.56 | loss 7.30 | ppl 1486.864
| epoch 4 step 45800 | 11390 batches | lr 0.000219 | ms/batch 621.13 | loss 7.29 | ppl 1470.398
| epoch 5 step 46000 | 120 batches | lr 0.000219 | ms/batch 621.16 | loss 7.29 | ppl 1462.640
| epoch 5 step 46200 | 320 batches | lr 0.000219 | ms/batch 622.12 | loss 7.27 | ppl 1429.701
| epoch 5 step 46400 | 520 batches | lr 0.000218 | ms/batch 622.40 | loss 7.28 | ppl 1455.915
| epoch 5 step 46600 | 720 batches | lr 0.000218 | ms/batch 621.56 | loss 7.27 | ppl 1429.803
| epoch 5 step 46800 | 920 batches | lr 0.000218 | ms/batch 621.79 | loss 7.28 | ppl 1447.379
| epoch 5 step 47000 | 1120 batches | lr 0.000217 | ms/batch 621.45 | loss 7.28 | ppl 1449.542
| epoch 5 step 47200 | 1320 batches | lr 0.000217 | ms/batch 621.70 | loss 7.28 | ppl 1444.158
| epoch 5 step 47400 | 1520 batches | lr 0.000217 | ms/batch 622.01 | loss 7.27 | ppl 1441.529
| epoch 5 step 47600 | 1720 batches | lr 0.000217 | ms/batch 621.49 | loss 7.28 | ppl 1456.560
| epoch 5 step 47800 | 1920 batches | lr 0.000216 | ms/batch 621.47 | loss 7.27 | ppl 1440.568
| epoch 5 step 48000 | 2120 batches | lr 0.000216 | ms/batch 621.64 | loss 7.29 | ppl 1461.721

| Eval 12 at step 48000 | time: 2491.63s | valid loss 7.47 | valid ppl 1752.609

ZeroDivisionError: integer division or modulo by zero

ub16c9@ub16c9-gpu:~/ub16_prj/transformer-xl/pytorch$ bash run_enwik8_base.sh train --work_dir enwiki8_task
Run training...
Experiment dir : enwiki8_task-enwik8/20190127-180347
Loading cached dataset...

- clip : 0.25
- eta_min : 0.0
- finetune_v3 : False
- n_layer : 12
- pre_lnorm : False
- n_head : 8
- proj_init_std : 0.01
- emb_init_range : 0.01
- fp16 : False
- n_nonemb_param : 40949760
- scheduler : cosine
- work_dir : enwiki8_task-enwik8/20190127-180347
- batch_size : 22
- debug : False
- dropatt : 0.0
- init_std : 0.02
- lr : 0.00025
- cuda : True
- data : ../data/enwik8/
- emb_init : normal
- ext_len : 0
- sample_softmax : -1
- eval_tgt_len : 128
- restart_dir : 
- mom : 0.0
- clamp_len : -1
- max_eval_steps : -1
- batch_chunk : 1
- multi_gpu : True
- mem_len : 512
- dynamic_loss_scale : False
- d_embed : 512
- max_step : 400000
- attn_type : 0
- lr_min : 0.0
- static_loss_scale : 1
- init : normal
- patience : 0
- dropout : 0.1
- finetune_v2 : False
- d_head : 64
- same_length : False
- dataset : enwik8
- init_range : 0.1
- d_model : 512
- tgt_len : 512
- optim : adam
- d_inner : 2048
- warmup_step : 0
- restart : False
- seed : 1111
- adaptive : False
- n_token : 204
- log_interval : 200
- varlen : False
- tied : True
- clip_nonemb : False
- decay_rate : 0.5
- div_val : 1
- gpu0_bsz : 4
- not_tied : False
- eval_interval : 4000
- n_all_param : 41055436

====================================================================================================
#params = 41055436
#non emb params = 40949760
Traceback (most recent call last):
File "train.py", line 539, in
train()
File "train.py", line 445, in train
ret = para_model(data, target, *mems)
File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "/home/ub16c9/ub16_prj/transformer-xl/pytorch/utils/data_parallel.py", line 64, in forward
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
File "/home/ub16c9/ub16_prj/transformer-xl/pytorch/utils/data_parallel.py", line 80, in scatter
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
ZeroDivisionError: integer division or modulo by zero
ub16c9@ub16c9-gpu:~/ub16_prj/transformer-xl/pytorch$
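
The division comes from spreading the non-GPU-0 portion of the batch over the remaining devices, so it breaks when only one GPU is visible. A hedged workaround sketch (not the repository's code) is to skip the unbalanced wrapper in that case:

import torch
import torch.nn as nn

def wrap_model(model, gpu0_bsz):
    # With a single visible device, (num_dev - 1) == 0 and the split formula
    # bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) divides by zero, so fall back.
    num_dev = torch.cuda.device_count()
    if num_dev <= 1:
        return model.cuda() if num_dev == 1 else model
    return nn.DataParallel(model).cuda()   # placeholder for the repo's custom data_parallel wrapper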

Penn Treebank and WikiText-2 architectures

Hello!

Could you please provide hyperparameters for training models with close-to-SOTA perplexity on PTB and WT2 (if you experimented with the latter, as it has a corresponding option in the data utils)? Am I right that the two changes I need to make to the released code are adding variational dropout and the ASGD optimizer? If you have code that produces the necessary changes, that would be great.

Thanks

problem when run sota/enwik8.sh

When I run bash sota/enwik8.sh,

Preprocess test set...
Loading cached dataset...
Traceback (most recent call last):
File "data_utils.py", line 586, in
tf.app.run(main)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "data_utils.py", line 382, in main
corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)
File "data_utils.py", line 345, in get_lm_corpus
corpus = pickle.load(fp)
UnicodeDecodeError: 'ascii' codec can't decode byte 0x88 in position 153592: ordinal not in range(128)
Run evaluation on test set...
I0304 14:20:51.687984 139744890042112 tf_logging.py:115] n_token 204
Traceback (most recent call last):
File "train_gpu.py", line 475, in
tf.app.run()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "train_gpu.py", line 471, in main
evaluate(n_token, cutoffs, "/gpu:0")
File "train_gpu.py", line 366, in evaluate
use_tpu=False)
File "/home/gaodihe/PycharmProjects/transformer-xl/tf/data_utils.py", line 424, in get_input_fn
num_core_per_host, use_tpu=use_tpu)
File "/home/gaodihe/PycharmProjects/transformer-xl/tf/data_utils.py", line 415, in load_record_info
with open(record_info_path, "r") as fp:
FileNotFoundError: [Errno 2] No such file or directory: './/pretrained_xl/tf_enwik8/data/tfrecords/record_info-test.bsz-16.tlen-128.json'

this error happens.

How can I fix it?
Thanks.

PyTorch: pretrained models

Hi,

thanks for releasing the TensorFlow and PyTorch code for your Transformer-XL ❤️

I would like to ask whether you plan to provide pre-trained models for the PyTorch implementation. I was only able to find the TensorFlow checkpoints...

Thanks in advance,

Stefan

Sensitivity to initial weights causing NANs?

Hi, I'm getting NaN values in the first forward pass of the model (in the first layer), generally caused by the first AC calculation. I'm wondering if this is an issue with the initial weights of the model? If so, any advice on dealing with it would help. I have made some changes to the model, and this will help me determine whether it is a known issue or whether I have introduced a bug. Thanks.

Finetuning

I noticed that there are two flags in the PyTorch train.py script (--finetune_v2 and --finetune_v3) that don't seem to be used anywhere in the code. These flags suggest that there might be something special I need to do to finetune Transformer-XL. Might I be missing something?

Currently, I am running finetuning experiments simply by specifying --restart as well as --restart_dir and changing the dataset.

OOM issue when training 1 billion corpus

I am trying to train on the One Billion Word corpus on a Tesla P40. The following values are being used:

N_LAYER = 12
D_MODEL = 512
D_EMBED = 512
D_INNER = 2048
D_HEAD = 64

I also tried with a batch size of 128, and it still gives an OOM error.
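
One knob that appears elsewhere on this page is --batch_chunk, which splits each batch into smaller chunks processed sequentially, so peak activation memory shrinks while gradients accumulate. A generic sketch of the idea (not the repository's training loop):

import torch
import torch.nn as nn

model = nn.Linear(512, 512)                           # stand-in for the real model
opt = torch.optim.Adam(model.parameters(), lr=2.5e-4)
data, target = torch.randn(128, 512), torch.randn(128, 512)

batch_chunk = 4                                       # process the batch in 4 pieces
opt.zero_grad()
for d, t in zip(data.chunk(batch_chunk), target.chunk(batch_chunk)):
    loss = nn.functional.mse_loss(model(d), t) / batch_chunk
    loss.backward()                                   # gradients accumulate across chunks
opt.step()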

TPU settings

Hi,

I would like to train a model on a TPU, but I'm not able to find the correct settings for a v2-8 TPU.

What values are needed for NUM_HOST and NUM_CORE? I tried different values, but I always get "num_replicas should be (8), got (XXX)." error messages.

What TPU model did you use for the 1 Billion word benchmark?

Can I create the tfrecords locally (on a non-TPU machine) in the train_data step?

Thanks :)

Train a new corpus!

What changes do we need to make inside the script to train on a new corpus?

I have checked the script, and there are a lot of if conditions that depend on each specific corpus.
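
As a rough, hypothetical sketch of the pieces a new dataset branch usually needs (build a vocabulary from the training file and encode each split into a flat tensor of token ids), with made-up helper names rather than the repository's exact API:

import torch
from collections import Counter

def build_vocab(path, min_freq=1):
    # Hypothetical helper: whitespace-tokenized vocabulary from the training file.
    counter = Counter()
    with open(path, encoding='utf-8') as f:
        for line in f:
            counter.update(line.split())
    symbols = ['<unk>', '<eos>'] + [w for w, c in counter.most_common() if c >= min_freq]
    return {w: i for i, w in enumerate(symbols)}

def encode_file(path, vocab):
    # Hypothetical helper: encode one split into a single flat LongTensor of ids.
    ids = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            ids.extend(vocab.get(w, vocab['<unk>']) for w in line.split() + ['<eos>'])
    return torch.LongTensor(ids)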
