trax's Introduction

Trax — Deep Learning with Clear Code and Speed

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in Colab) shows how to use Trax and where you can find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! PRs with code for new models and layers are welcome, as are improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!

A few example notebooks can be found in the repository.

General Setup

Execute the following cell (once) before running any of the code samples.

import os
import numpy as np

!pip install -q -U trax
import trax

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow Datasets.
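
For instance, any of these models can be constructed directly from trax.models; a minimal sketch (the hyperparameters here are illustrative, not recommended settings):

# Build a Transformer language model in training mode.
lm = trax.models.TransformerLM(vocab_size=32000, d_model=512, mode='train')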

You can use Trax either as a library from your own Python scripts and notebooks, or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.
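
For example, a shell run looks roughly like this (a sketch; reformer_wmt_ende.gin is one of the gin configs shipped in trax/configs/):

python -m trax.trainer --config_file=$PWD/trax/configs/reformer_wmt_ende.gin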

3. Walkthrough

Here you can learn how Trax works, how to create new models, and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors: multi-dimensional arrays, often called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide; Trax also uses the numpy API for them.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package, thanks to its backends: JAX and TensorFlow NumPy.

from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix = 
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]

Gradients can be calculated using trax.fastmath.grad.

def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
grad(2x^2) at 1 = 4.0
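
trax.fastmath also exposes just-in-time compilation, which is how these operations get their speed on accelerators. A minimal sketch (assuming the jax backend set above):

# Compile a function once; later calls run the compiled version.
def norm(x):
  return fastnp.sqrt(fastnp.sum(x * x))

norm_fast = trax.fastmath.jit(norm)
print(f'norm([3, 4]) = {norm_fast(fastnp.array([3.0, 4.0]))}')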

Layers

Layers are the basic building blocks of Trax models. You will learn all about them in the layers intro, but for now, just take a look at the implementation of one core Trax layer, Embedding:

class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

from trax import layers as tl

# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')

# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y = (15, 32)

Models

Models in Trax are built from layers, most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/; for example, this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)
Serial[
  Embedding_8192_256
  Mean
  Dense_2
  LogSoftmax
]
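
Like single layers, a model can be initialized from an input signature and then called; a quick sketch with dummy token IDs (the weights are random here, so the outputs are meaningless until training):

# Initialize and run the untrained model on a dummy batch.
dummy_ids = np.zeros((2, 10), dtype=np.int32)  # Batch of 2 sequences of length 10.
model.init(trax.shapes.signature(dummy_ids))
log_probs = model(dummy_ids)
print(log_probs.shape)  # (2, 2) -- two class log-probabilities per example.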

Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax allows you to use TensorFlow Datasets easily, and you can also get an iterator from your own text file using the standard open('my_file.txt').

train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream))  # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)

Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines with trax.data.Serial; they are functions that you apply to streams to create processed streams.

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[256,  64,  16,    4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
shapes = [(4, 1024), (4,), (4,)]

Supervised training

When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)
Step      1: Ran 1 train steps in 0.78 secs
Step      1: train WeightedCategoryCrossEntropy |  1.33800304
Step      1: eval  WeightedCategoryCrossEntropy |  0.71843582
Step      1: eval      WeightedCategoryAccuracy |  0.56562500

Step    500: Ran 499 train steps in 5.77 secs
Step    500: train WeightedCategoryCrossEntropy |  0.62914723
Step    500: eval  WeightedCategoryCrossEntropy |  0.49253047
Step    500: eval      WeightedCategoryAccuracy |  0.74062500

Step   1000: Ran 500 train steps in 5.03 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.42949259
Step   1000: eval  WeightedCategoryCrossEntropy |  0.35451687
Step   1000: eval      WeightedCategoryAccuracy |  0.83750000

Step   1500: Ran 500 train steps in 4.80 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.41843575
Step   1500: eval  WeightedCategoryCrossEntropy |  0.35207348
Step   1500: eval      WeightedCategoryAccuracy |  0.82109375

Step   2000: Ran 500 train steps in 5.35 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.38129005
Step   2000: eval  WeightedCategoryCrossEntropy |  0.33760912
Step   2000: eval      WeightedCategoryAccuracy |  0.85312500

After training the model, run it like any layer to get results.

example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: I first saw this when I was a teen in my last year of Junior High. I was riveted to it! I loved the special effects, the fantastic places and the trial-aspect and flashback method of telling the story.<br /><br />Several years later I read the book and while it was interesting and I could definitely see what Swift was trying to say, I think that while it's not as perfect as the book for social commentary, as a story the movie is better. It makes more sense to have it be one long adventure than having Gulliver return after each voyage and making a profit by selling the tiny Lilliput sheep or whatever.<br /><br />It's much more arresting when everyone thinks he's crazy and the sheep DO make a cameo anyway. As a side note, when I saw Laputa I was stunned. It looks very much like the Kingdom of Zeal from the Chrono Trigger video game (1995) that also made me like this mini-series even more.<br /><br />I saw it again about 4 years ago, and realized that I still enjoyed it just as much. Really high quality stuff and began an excellent run of Sweeps mini-series for NBC who followed it up with the solid Merlin and interesting Alice in Wonderland.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[3.984500e-04 9.996014e-01]]
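
The training loop also saved checkpoints in output_dir; a sketch of restoring the final weights into a fresh copy of the model (assuming the default checkpoint file name model.pkl.gz written by training.Loop):

model_restored = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),
    tl.Dense(2),
    tl.LogSoftmax()
)
model_restored.init_from_file(os.path.join(output_dir, 'model.pkl.gz'),
                              weights_only=True)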

trax's People

Contributors

adarob, afrozenator, cclauss, hawkinsp, henrykmichalewski, j2i2, jalammar, kkanska, koz4k, lukaszkaiser, manifest, modyharshit23, mtyrolski, nathanhowell, omaralsaqa, piotrnawrot, piotrpiekos, pkol, pkozakowski, pschuh, sauravmaheshkar, sebastianjaszczur, shadowatyyy, sunvod, syzymon, trax-robot, wangpengmit, weiddeng, yashkhasbage25, yilei


trax's Issues

Orientation / GCP Implementation / TPU / Migration from t2t

Description

Hello,

Please help us understand where we are heading with t2t being discontinued. We have a lot of scripts interacting with the t2t ecosystem. What is the perspective for Trax as a replacement? Will there be migration help? How does it integrate with TPUs on GCP? We couldn't find any announcements about what t2t users should expect.

Thx
Phillip

Any way to disable jit for debugging purposes?

I'm editing the trax codebase, and for debugging purposes I have to print values eagerly. But because of the jit, everything is a jit abstract expression (google/jax#196). How do you debug the code without being able to print anything? I assume that if I disable the jit, everything will be executed eagerly. Am I right? If yes, is there any way to disable jit?
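
One possible approach (a sketch, not from this thread; jax.disable_jit is a standard JAX escape hatch in recent versions) is to run the traced code eagerly so that print() shows concrete values:

import jax

with jax.disable_jit():
  run_code_to_debug()  # Hypothetical placeholder; values inside this block are concrete.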

Training Transformer on TPU

Description

Hello, I was wondering how large the batch size can be for TPU training. I'm currently training a vanilla Transformer model in Colab and I can barely fit into TPU memory. My batch size is 128, sequences are padded with the padded_batch function, and max_len is 512. It seems that I'm missing something, because it's a bit suspicious that the TPU cannot handle batches of a higher magnitude (like 2048).

I also tried to run the TPU profiler, but I could not, since the model doesn't output anything to keep track of.

That's why my question is: what are the best practices for training a Trax Transformer on TPUs?

Error log

RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 12.64G of 8.00G hbm. Exceeded hbm capacity by 4.64G.

Total hbm usage >= 12.64G:
    reserved        529.00M 
    program          12.13G 
    arguments       unknown size 

Output size unknown.

Program hbm requirement 12.13G:
    reserved           4.0K
    global           196.0K
    HLO temp         12.13G (58.5% utilization: Unpadded (7.09G) Padded (12.12G), 0.1% fragmentation (10.34M))

  Largest program allocations in hbm:

  1. Size: 937.50M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n                                 precision=None ]"
     Shape: f32[64,128,30000]{1,2,0:T(8,128)}
     Unpadded size: 937.50M
     XLA label: %fusion.1546 = (f32[64,128]{1,0:T(8,128)}, f32[64,128]{1,0:T(8,128)}, f32[64,128,30000]{1,2,0:T(8,128)}) fusion(f32[64,128]{1,0:T(8,128)} %fusion.9002.remat3, f32[64,128]{1,0:T(8,128)} %fusion.28213.remat, f32[30000]{0:T(1024)} %get-tuple-element.4759, f32...
     Allocation type: HLO temp
     ==========================

  2. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.249 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4769), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  3. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.248 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4765), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  4. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.247 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4761), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  5. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.246 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4757), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  6. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.245 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4753), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  7. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.244 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4747), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  8. Size: 512.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,512,512]{2,3,1,0:T(8,128)}
     Unpadded size: 512.00M
     XLA label: %fusion.2186 = (f32[64,8,512]{2,1,0:T(8,128)}, f32[64,8,512]{2,1,0:T(8,128)}, f32[64,8,512,512]{2,3,1,0:T(8,128)}) fusion(f32[64,8,512]{2,1,0:T(8,128)} %fusion.2753, pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4382, f32[64,8,512]{2,1,0:T(8,128)} %fu...
     Allocation type: HLO temp
     ==========================

  9. Size: 512.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,512,512]{2,3,1,0:T(8,128)}
     Unpadded size: 512.00M
     XLA label: %convolution-base-dilated.117.remat5 = f32[64,8,512,512]{2,3,1,0:T(8,128)} convolution(bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.312, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.314), window={size=64x8 stride=63x7 lhs_dilate=64x8}, dim_labels...
     Allocation type: HLO temp
     ==========================

  10. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4751 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3020), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  11. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4755 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3021), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  12. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4759 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3022), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  13. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4763 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3023), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  14. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4767 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3024), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  15. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4771 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3025), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  16. Size: 128.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
     Unpadded size: 32.00M
     Extra memory due to padding: 96.00M (4.0x expansion)
     XLA label: %reshape.4740 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2426), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  17. Size: 128.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
     Unpadded size: 32.00M
     Extra memory due to padding: 96.00M (4.0x expansion)
     XLA label: %reshape.4745 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2431), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  18. Size: 128.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
     Unpadded size: 32.00M
     Extra memory due to padding: 96.00M (4.0x expansion)
     XLA label: %reshape.4791 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2427), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  19. Size: 128.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,128,512]{3,2,1,0:T(8,128)}
     Unpadded size: 128.00M
     XLA label: %fusion.4304 = (f32[64,8,128]{2,1,0:T(8,128)}, f32[64,8,128,512]{3,2,1,0:T(8,128)}) fusion(pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4384, bf16[64,128,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.85, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.389)...
     Allocation type: HLO temp
     ==========================

  20. Size: 128.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,128,512]{3,2,1,0:T(8,128)}
     Unpadded size: 128.00M
     XLA label: %fusion.4305 = (f32[64,8,128]{2,1,0:T(8,128)}, f32[64,8,128,512]{3,2,1,0:T(8,128)}) fusion(pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4384, bf16[64,128,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.464, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.466...
     Allocation type: HLO temp
     ==========================

IndexError: list assignment index out of range in Transformer model

Description

I'm attempting to train a Transformer model for machine translation with a shared vocabulary. As expected, the input and target sequences have different lengths. I was expecting Trax to detect this and pad the sequences accordingly. I didn't see examples or documentation for this exact problem. Any advice would be greatly appreciated.

Environment information

OS: MacOS 10.14.6 (18G3020)

$ pip freeze | grep tensor
mesh-tensorflow==0.1.12
tensor2tensor==1.15.4
tensorboard==2.1.1
tensorflow==2.1.0
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.62
jaxlib==0.1.42

$ python -V
Python 3.7.6

For bugs: reproduction and error logs

# Steps to reproduce:
1. Download a parallel text corpus. 
2. Create a vocabulary and tokenize the source and target text and save as TFRecords.
3. Run the following code to train a Transformer model:
import os

import tensorflow as tf
import trax

from src.common.params import Paths


def train_model(inputs: trax.supervised.Inputs, model_function, output_dir):

  trainer = trax.supervised.Trainer(
    model=model_function,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adafactor,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=inputs,
    output_dir=output_dir
  )

  n_epochs = 10
  train_steps = 10
  eval_steps = 1
  for _ in range(n_epochs):
    trainer.train_epoch(train_steps, eval_steps)


def parse_example(serialized_example):
  """Return inputs and targets Tensors from a serialized tf.Example."""
  data_fields = {
      "inputs": tf.io.VarLenFeature(tf.int64),
      "targets": tf.io.VarLenFeature(tf.int64)
  }
  parsed = tf.io.parse_single_example(serialized_example, data_fields)
  inputs = tf.sparse.to_dense(parsed["inputs"])
  targets = tf.sparse.to_dense(parsed["targets"])
  return inputs, targets


def file_length(filename):
  with open(filename) as f:
    for i, l in enumerate(f):
      pass
  return i + 1


def main():

  ML_ROOT = os.path.join(Paths.data_root, 'machine_translation')
  trax_path = os.path.join(ML_ROOT, 'trax')
  tokenizer_path = os.path.join(trax_path, 'subtoken.vocab')
  tokenized_records_path = os.path.join(trax_path, 'tokenized')
  os.makedirs(trax_path, exist_ok=True)
  os.makedirs(tokenized_records_path, exist_ok=True)

  tf_record_filenames = [os.path.join(tokenized_records_path, p) for p in
                         os.listdir(tokenized_records_path)]

  dataset = tf.data.TFRecordDataset(tf_record_filenames).map(parse_example)

  inputs = trax.supervised.Inputs(
    train_stream=lambda _: dataset.as_numpy_iterator(),
    eval_stream=lambda _: dataset.as_numpy_iterator()
  )

  # Peek into the inputs.
  data_stream = inputs.train_stream(n_devices=1)
  for _ in range(10):
    sample_input, sample_target = next(data_stream)
    print('-' * 100)
    print("Inputs:  %s, len: %s" % (str(sample_input), str(len(sample_input))))
    print("Targets: %s, len: %s" % (str(sample_target), str(len(sample_target))))

  vocab_size = file_length(tokenizer_path)
  print('Vocab size:', vocab_size)

  def transformer(mode):
    return trax.models.Transformer(
      vocab_size,
      mode=mode
    )

  train_model(inputs, model_function=transformer, output_dir=os.path.expanduser('~/train_dir/'))


if __name__ == '__main__':
  main()

Other output:

----------------------------------------------------------------------------------------------------
Inputs:  [704 656  32 769   2 588 820 936   2  47   4   1], len: 12
Targets: [946 947 950 942 937 462   2   7   5 238  21 377 336   4   1], len: 15
----------------------------------------------------------------------------------------------------
Inputs:  [798 128 221 866   2 249  37 471 912   4   1], len: 11
Targets: [946 947 950 944 937   2 338 188 190 383 301 106   4   1], len: 14
----------------------------------------------------------------------------------------------------
Inputs:  [ 55 641  34 425 685  53 426 391 356 426   4   1], len: 12
Targets: [946 947 948 953 937 216  10  46 230 104 196 172 364 187 187  21   4   1], len: 18
----------------------------------------------------------------------------------------------------
Inputs:  [167 475  45  20 122 139  48   2  56 148  76  56 246  33  53 424 299 209
 220 687   2   4   1], len: 23
Targets: [988 240  64 127 719  29  92 346 380 109 206 292 163 378   5  67   4   1], len: 18
----------------------------------------------------------------------------------------------------
Inputs:  [ 66 156  84 687  43 246  56 687   2  40 246  81 687 739 258  56 148  30
 687  50 156  83  54 246  56 687   2 697 156  30 687  45   4   1], len: 34
Targets: [946 947 950 944 937   2  71 698 980 932 827 827   2  13 322   4   1], len: 17
----------------------------------------------------------------------------------------------------
Inputs:  [129 665 256 145 470   4   1], len: 7
Targets: [946 947 950 944 937   2 802   5  12   4   1], len: 11
----------------------------------------------------------------------------------------------------
Inputs:  [225 255 861  85 148  45  28  56 148  50 687  95 246  62 148 165   2 431
  51 181 246 393 739  86  66 148 432  41   4   1], len: 30
Targets: [946 947 950 944 937   2 123 380 304  77  11 745  16  67 153 709 126 618
 462   2   4   1], len: 22
----------------------------------------------------------------------------------------------------
Inputs:  [438 118  95 128 903 162  69 282 239   4   1], len: 11
Targets: [946 947 950 944 937  10  18 311 394 404 311 778 119 238  64  27   4   1], len: 18
----------------------------------------------------------------------------------------------------
Inputs:  [101 602 195 310  37  19   1], len: 7
Targets: [946 947 942 951 937 261  12 231 261  17   1], len: 11
----------------------------------------------------------------------------------------------------
Inputs:  [162 253 481 128  78 141 161 145   4   1], len: 10
Targets: [946 947 950 944 937  10  91 213 357   5 106   6 410 189 613  10   4   1], len: 18
Vocab size: 992
# Error logs:
/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Traceback (most recent call last):
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 480, in _forward_abstract
    input_signature, weight_signature, self.state, rng)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/math/jax.py", line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 2104, in eval_shape
    *map(abstractify, args_flat))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 274, in abstract_eval_fun
    instantiate=True)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 220, in forward_with_state
    return self.forward(inputs, weights), state
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/attention.py", line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1
IndexError: list assignment index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
    weights, state = self.new_weights_and_state(input_signature)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 485, in _forward_abstract
    trace)
trax.layers.base.LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
  layer created in file [...]/trax/models/transformer.py, line 291
  layer input shapes: ShapeDtype{shape:(1,), dtype:int32}

  File [...]/trax/math/jax.py, line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)

  File [...]/site-packages/jax/api.py, line 2104, in eval_shape
    *map(abstractify, args_flat))

  File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
    instantiate=True)

  File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/trax/layers/base.py, line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  File [...]/trax/layers/base.py, line 220, in forward_with_state
    return self.forward(inputs, weights), state

  File [...]/trax/layers/base.py, line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access

  File [...]/trax/layers/attention.py, line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1

IndexError: list assignment index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
    weights, state = self.new_weights_and_state(input_signature)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
    weights_or_empty, state = sublayer.init(inputs)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
    input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/models/transformer.py, line 301
  layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})

  File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
  layer created in file [...]/trax/models/transformer.py, line 291
  layer input shapes: ShapeDtype{shape:(1,), dtype:int32}

  File [...]/trax/math/jax.py, line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)

  File [...]/site-packages/jax/api.py, line 2104, in eval_shape
    *map(abstractify, args_flat))

  File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
    instantiate=True)

  File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/trax/layers/base.py, line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  File [...]/trax/layers/base.py, line 220, in forward_with_state
    return self.forward(inputs, weights), state

  File [...]/trax/layers/base.py, line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access

  File [...]/trax/layers/attention.py, line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1

IndexError: list assignment index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "src/machine_translation/trax/main.py", line 97, in <module>
    main()
  File "src/machine_translation/trax/main.py", line 93, in main
    output_dir=os.path.expanduser('~/train_dir/'))
  File "src/machine_translation/trax/main.py", line 21, in train_model
    output_dir=output_dir
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 217, in __init__
    self.reset(output_dir)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 297, in reset
    opt_state, model_state = self._new_opt_state_and_model_state()
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 170, in <lambda>
    model_target_shape, self._inputs.target_dtype, init_rng))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
    name=flat_fun.__name__)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/core.py", line 895, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 457, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 220, in memoized_fun
    ans = call(fun, *args)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 474, in _xla_callable
    fun, pvals, instantiate=False, stage_out_calls=True, bottom=True)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 159, in new_opt_state_and_model_state
    weights, state = m.init(input_signature)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
    input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/supervised/trainer_lib.py, line 157
  layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})

  File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
    weights_or_empty, state = sublayer.init(inputs)

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/models/transformer.py, line 301
  layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})

  File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
  layer created in file [...]/trax/models/transformer.py, line 291
  layer input shapes: ShapeDtype{shape:(1,), dtype:int32}

  File [...]/trax/math/jax.py, line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)

  File [...]/site-packages/jax/api.py, line 2104, in eval_shape
    *map(abstractify, args_flat))

  File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
    instantiate=True)

  File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/trax/layers/base.py, line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  File [...]/trax/layers/base.py, line 220, in forward_with_state
    return self.forward(inputs, weights), state

  File [...]/trax/layers/base.py, line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access

  File [...]/trax/layers/attention.py, line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1

IndexError: list assignment index out of range

[LSHSelfAttention] Why does q_start have to be of type int in incremental_forward_unbatched?

Question:

In this line:

if isinstance(q_start, int) and q_start == 0 and q_len > 1:

q_start is checked to be of type int. This means that if, e.g., q_start = DeviceArray(0, dtype=int32), the code jumps directly into handling one token at a time, where it is checked that q_len == 1.

Removing the isinstance(q_start, int) check works fine if I want to handle more than one token at a time.

What is this isinstance check good for?

[Feature request] Examples for advanced machine learning tasks

Since Trax is a successor of tensor2tensor (according to the release notes of tensor2tensor v1.15.0), it would be helpful if you could provide examples for more advanced machine learning tasks. An outstanding feature of tensor2tensor is its numerous (and useful) examples, which Trax currently lacks. Such examples would be especially helpful for machine learning tasks with complex input transformations, like speech recognition or translation with subword encodings.

Multi-gpu training

Description

Hi. I'm trying to train ReformerLM, using the training-loop code from this tutorial: https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb#scrollTo=djTiSLcaNFGa. Training starts normally, but Trax doesn't utilize the second GPU at all: the model is loaded onto the second GPU, but its GPU-Util stays at 0% while the first GPU is at 100%.
I tried changing the batch size (it is now set to 8), but if I increase it to 10, training fails with an OOM error.

Can you please provide code for multi-GPU training?

Environment information

OS: ubuntu 18.04

$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.0.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39

$ python -V
Python 3.7.6

[Feature request] Parallel evaluation

Right now training and evaluation are interleaved. This means that increasing the eval steps or frequency slows down training.
It would be great to have the option to spawn a separate process for evaluation.

Tutorial on how to extract reformer layers as tf.layers

Hello, do you have any tutorial on how to extract the reversible self-attention layer as a tf.layer? Is it possible? Would it be possible to just take the self-attention layer and integrate it into BERT? That would be amazing! Any tutorial on how to integrate JAX with TF would also be amazing. Thanks!

Custom embedding in Transformer

Description

Is there a way to incorporate a custom input embedding while retaining the abstraction of the library for Transformers?

Colab Reformer Prediction Could not allocate bytes in memory

Description

After using Colab for training and then loading the model in prediction mode, it runs out of memory on the second prediction run on the TPU runtime:
https://colab.research.google.com/drive/1v2q5Qp2-68hLG-uTZ3gZZHvkm9Ovbpkc

Reformer model details:

def reformer(mode):
  return trax.models.reformer.ReformerLM(
    d_model=32,
    d_ff=128,
    n_layers=8,
    vocab_size=1024,
    mode=mode)

Sequence Length = 100
batch size = 128
...

Environment information

OS: Google Colab

$ pip freeze | grep tensor
mesh-tensorflow==0.1.13
tensor2tensor==1.15.4
tensorboard==2.2.0
tensorboard-plugin-wit==1.6.0.post2
tensorboardcolab==0.0.22
tensorflow==2.2.0rc2
tensorflow-addons==0.8.3
tensorflow-datasets==2.1.0
tensorflow-estimator==2.2.0rc0
tensorflow-gan==2.0.0
tensorflow-gcs-config==2.1.8
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-privacy==0.2.2
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39

$ python -V
Python 3.6.9

For bugs: reproduction and error logs

# Steps to reproduce:
Run all cells up to the "Speed" markdown cell
# Error logs:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    443       else:
--> 444         outputs, s = self._do_custom_gradients(x, weights, state, rng=rng)
    445       self._state = s

16 frames
RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available

During handling of the above exception, another exception occurred:

LayerError                                Traceback (most recent call last)
LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 802
  layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})

  File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
    output, state = _do_forward(x, weights)

  File [...]/dist-packages/jax/api.py, line 1460, in __call__
    num_consts=len(consts))

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
    return compiled_fun(*args)

  File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
    out_buf = compiled.Execute(input_bufs)

RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available

During handling of the above exception, another exception occurred:

LayerError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    449       name, trace = self.__class__.__name__, _short_traceback()
    450       raise LayerError(name, 'pure_fn',
--> 451                        self._caller, signature(x), trace)
    452 
    453   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 811
  layer input shapes: ShapeDtype{shape:(100, 1), dtype:int32}

  File [...]/trax/layers/combinators.py, line 77, in forward_with_state
    outputs, s = layer.pure_fn(inputs, w, s, rng)

LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 802
  layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})

  File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
    output, state = _do_forward(x, weights)

  File [...]/dist-packages/jax/api.py, line 1460, in __call__
    num_consts=len(consts))

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
    return compiled_fun(*args)

  File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
    out_buf = compiled.Execute(input_bufs)

RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available

LSH for Enc-Dec attention

Have you implemented LSH for Enc-Dec attention? I know the motivation behind full attention was that Enc-Dec is mostly used for MT, where full attention should be OK. But I'm using it for longer sequences and I'm hitting an OOM issue, so I wanted to know whether you have implemented LSH for Enc-Dec attention.

Training speed of NMT models

Description

I modified the ende reformer config to train my own Reformer model for a low-resource language pair (18,000 sentence pairs). Note that I am using GPUs and not TPUs. I found that the Reformer encoder-decoder trains very slowly (1 second per batch on a V100). Is this normal? I was under the impression that the Reformer trains fast. Am I missing something?
...

Environment information

OS: CentOS

$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.14.0
tensorboard==1.15.0
tensorflow-datasets==1.3.2
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.0
tensorflow-probability==0.7.0


$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37
I have installed the GPU versions of these.

$ python -V
Python 3.6.8

For bugs: reproduction and error logs

# Steps to reproduce:
...

I run: python -m trax.trainer --config_file=$PWD/trax/configs/reformer_wmt_ende.gin

# Error logs:
...

N/A

tensorflow.train import crashes in bert.py

Description

In models/bert.py the line
from tensorflow.train import load_checkpoint
crashes with
Traceback (most recent call last):
  File "math_trax.py", line 19, in <module>
    import trax
  File "/root/.local/lib/python3.6/site-packages/trax/__init__.py", line 19, in <module>
    from trax import lr_schedules as lr
  File "/root/.local/lib/python3.6/site-packages/trax/lr_schedules.py", line 37, in <module>
    from trax import models as trax_models
  File "/root/.local/lib/python3.6/site-packages/trax/models/__init__.py", line 32, in <module>
    from trax.models.research import bert
  File "/root/.local/lib/python3.6/site-packages/trax/models/research/bert.py", line 20, in <module>
    from tensorflow.train import load_checkpoint
ModuleNotFoundError: No module named 'tensorflow.train'

If we change that (in bert.py) to
#from tensorflow.train import load_checkpoint
from tensorflow_core._api.v2.train import load_checkpoint

the import works.
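A more version-agnostic sketch (assuming the failure is only the lazily-built tensorflow.train module): access the symbol as an attribute instead of importing the submodule directly. tf.train.load_checkpoint is public API in both TF 1.15 and 2.x.

import tensorflow as tf

# Attribute access triggers TF's lazy module loading, so this works even
# where `from tensorflow.train import load_checkpoint` fails.
load_checkpoint = tf.train.load_checkpoint
# reader = load_checkpoint('/path/to/checkpoint')  # usage; path is illustrative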

Is it my environment or setup? Sorry if so; I tried to rule that out, but to no avail.

...

Environment information

OS: ubuntu 18.04 in a docker container

mesh-tensorflow==0.1.11
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39
(I know I need to upgrade, but I deem that unrelated.)

$ python3 -V
Python 3.6.9

For bugs: reproduction and error logs

# Steps to reproduce:

import trax


Error logs:

Traceback (most recent call last):
  File "math_trax.py", line 19, in <module>
    import trax
  File "/root/.local/lib/python3.6/site-packages/trax/__init__.py", line 19, in <module>
    from trax import lr_schedules as lr
  File "/root/.local/lib/python3.6/site-packages/trax/lr_schedules.py", line 37, in <module>
    from trax import models as trax_models
  File "/root/.local/lib/python3.6/site-packages/trax/models/__init__.py", line 32, in <module>
    from trax.models.research import bert
  File "/root/.local/lib/python3.6/site-packages/trax/models/research/bert.py", line 20, in <module>
    from tensorflow.train import load_checkpoint
ModuleNotFoundError: No module named 'tensorflow.train'

Reformer Model for Speech Recognition

Coming from tensor2tensor, I was wondering whether the Reformer model would also be a candidate for speech recognition? Looking at the examples, there is none for ASR.

Would it be possible to train an ASR model with the Reformer, or would code changes be necessary? If so, can we estimate how much of the model implementation would have to change?

Thank you for any insight into this!

Reformer TPU training

Description

Hello,

I train my Reformer model with these parameters:

n_encoder_layers = 3,
n_decoder_layers = 3,
d_model = 512,
ff_size = 2048,
attention_heads_num = 8,
dropout = 0.1,
max_len=250

on a Google Colab TPU. My sequences are padded to the same length, so I feed these shapes to the Reformer:

x.shape = (batch_size, 256)
y.shape  = (batch_size, 128)

The largest batch size the TPU accepts is 2048, no more. A training step takes approximately 1.15 seconds, which means I can feed 2048 * (256 + 128) = 786432 int64 numbers to all 8 TPU cores. If I choose a bigger batch (e.g. 4096), it doesn't fit into memory and shows me this error stack trace:

RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 14.93G of 8.00G hbm. Exceeded hbm capacity by 6.93G.

Total hbm usage >= 14.93G:
    reserved        529.00M 
    program          14.41G 
    arguments       unknown size 

Output size unknown.

Program hbm requirement 14.41G:
    reserved           4.0K
    global           100.0K
    HLO temp         14.41G (98.3% utilization: Unpadded (14.16G) Padded (14.40G), 0.0% fragmentation (7.25M))

  Largest program allocations in hbm:

  1. Size: 5.97G
     Shape: f32[512,100,30000]{2,1,0:T(8,128)}
     Unpadded size: 5.72G
     Extra memory due to padding: 250.62M (1.0x expansion)
     XLA label: %copy.1598 = f32[512,100,30000]{2,1,0:T(8,128)} copy(f32[512,100,30000]{0,2,1} %copy.1597)
     Allocation type: HLO temp
     ==========================

  2. Size: 5.72G
     Shape: f32[512,100,30000]{0,2,1}
     Unpadded size: 5.72G
     XLA label: %copy.1597 = f32[512,100,30000]{0,2,1} copy(f32[512,100,30000]{2,1,0:T(8,128)} %get-tuple-element.2671)
     Allocation type: HLO temp
     ==========================

  3. Size: 1.00G
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[512,8,256,256]{2,3,1,0:T(8,128)}
     Unpadded size: 1.00G
     XLA label: %fusion.15833 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256,256]{2,3,1,0:T(8,128)}) fusion(f32[512,256]{1,0:T(8,128)} %get-tuple-element.2417, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.727, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.233),...
     Allocation type: HLO temp
     ==========================

  4. Size: 256.00M
     Operator: op_type="dot_general" op_name="dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n             precision=None ]"
     Shape: f32[512,256,512]{2,1,0:T(8,128)}
     Unpadded size: 256.00M
     XLA label: %fusion.659 = f32[512,256,512]{2,1,0:T(8,128)} fusion(f32[512,256,512]{2,1,0:T(8,128)} %get-tuple-element.2643, f32[512]{0:T(512)} %get-tuple-element.2794, f32[512,512]{1,0:T(8,128)} %reshape.4754, f32[512,256,2048]{2,1,0:T(8,128)} %fusion.276, f32[512,204...
     Allocation type: HLO temp
     ==========================

  5. Size: 256.00M
     Operator: op_type="reduce_sum" op_name="pmap(mapped_update)/reduce_sum[ axes=(2,)\n                                input_shape=(512, 256, 512) ]"
     Shape: f32[512,256,512]{2,1,0:T(8,128)}
     Unpadded size: 256.00M
     XLA label: %fusion.15774 = (f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}, f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}) fusion(f32[512,256,512]{2,1,0:T(8,128)} %get-tuple-element.2319, f32[512,256,512]{2,1,0:T(8,128)} %fusion.659, f32[...
     Allocation type: HLO temp
     ==========================

  6. Size: 256.00M
     Operator: op_type="reduce_sum" op_name="reduce_sum[ axes=(2,)\n            input_shape=(512, 256, 512) ]"
     Shape: f32[512,256,512]{2,1,0:T(8,128)}
     Unpadded size: 256.00M
     XLA label: %fusion.15776 = (f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}) fusion(f32[512,256,512]{2,1,0:T(8,128)} %get-tuple-element.2337, f32[512,256,512]{2,1,0:T(8,128)} %fusion.33), kind=kLoop, calls=%fused_computation.15091, metadata={op_type="red...
     Allocation type: HLO temp
     ==========================

  7. Size: 256.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n                                 precision=None ]"
     Shape: f32[512,256,512]{2,1,0:T(8,128)}
     Unpadded size: 256.00M
     XLA label: %fusion.15832 = (f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}) fusion(f32[512,256,512]{2,1,0:T(8,128)} %fusion.659, f32[512,512]{1,0:T(8,128)} %reshape.4767, f32[512]{0:T(512)} %get-tuple-element.2794, f32[2048,512,1]{1,0,2:T(8,128)} %bitca...
     Allocation type: HLO temp
     ==========================

  8. Size: 128.00M
     Operator: op_type="gather" op_name="pmap(mapped_update)/gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))\n                            operand_shape=(512, 256, 512)\n                            slice_sizes=(1, 256, 512) ]"
     Shape: bf16[512,256,512]{2,1,0:T(8,128)(2,1)}
     Unpadded size: 128.00M
     XLA label: %fusion.5 = bf16[512,256,512]{2,1,0:T(8,128)(2,1)} fusion(bf16[512,256,512]{2,1,0:T(8,128)(2,1)} %fusion.593, s32[512]{0} %get-tuple-element.2747), kind=kCustom, calls=%fused_computation.5, metadata={op_type="gather" op_name="pmap(mapped_update)/gather[ di...
     Allocation type: HLO temp
     ==========================

  9. Size: 128.00M
     Operator: op_type="gather" op_name="gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))\n        operand_shape=(512, 256, 512)\n        slice_sizes=(1, 256, 512) ]"
     Shape: bf16[512,256,512]{2,1,0:T(8,128)(2,1)}
     Unpadded size: 128.00M
     XLA label: %fusion.10 = bf16[512,256,512]{2,1,0:T(8,128)(2,1)} fusion(bf16[512,256,512]{2,1,0:T(8,128)(2,1)} %fusion.660, s32[512]{0} %get-tuple-element.2748), kind=kCustom, calls=%fused_computation.10, metadata={op_type="gather" op_name="gather[ dimension_numbers=Ga...
     Allocation type: HLO temp
     ==========================

  10. Size: 128.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: bf16[512,8,256,64]{0,3,2,1:T(8,128)(2,1)}
     Unpadded size: 128.00M
     XLA label: %copy.980 = bf16[512,8,256,64]{0,3,2,1:T(8,128)(2,1)} copy(bf16[512,8,256,64]{2,3,1,0:T(8,128)(2,1)} %fusion.349), metadata={op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n               ...
     Allocation type: HLO temp
     ==========================

  11. Size: 100.00M
     Shape: f32[512,100,512]{2,0,1}
     Unpadded size: 100.00M
     XLA label: %copy.1046 = f32[512,100,512]{2,0,1} copy(f32[512,100,512]{2,1,0:T(8,128)} %get-tuple-element.2670)
     Allocation type: HLO temp
     ==========================

  12. Size: 100.00M
     Shape: f32[512,100,512]{2,0,1}
     Unpadded size: 100.00M
     XLA label: %copy.1044 = f32[512,100,512]{2,0,1} copy(f32[512,100,512]{2,1,0:T(8,128)} %get-tuple-element.2224)
     Allocation type: HLO temp
     ==========================

  13. Size: 100.00M
     Shape: f32[512,100,512]{2,0,1}
     Unpadded size: 100.00M
     XLA label: %copy.1052 = f32[512,100,512]{2,0,1} copy(f32[512,100,512]{2,1,0:T(8,128)} %get-tuple-element.2222)
     Allocation type: HLO temp
     ==========================

  14. Size: 4.00M
     Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
     Shape: f32[512,2048]{1,0:T(8,128)}
     Unpadded size: 4.00M
     XLA label: %reshape.4770 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1786), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
     Allocation type: HLO temp
     ==========================

  15. Size: 4.00M
     Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
     Shape: f32[512,2048]{1,0:T(8,128)}
     Unpadded size: 4.00M
     XLA label: %reshape.4773 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1785), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
     Allocation type: HLO temp
     ==========================

  16. Size: 4.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[512,8,256]{2,1,0:T(8,128)}
     Unpadded size: 4.00M
     XLA label: %fusion.15833 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256,256]{2,3,1,0:T(8,128)}) fusion(f32[512,256]{1,0:T(8,128)} %get-tuple-element.2417, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.727, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.233),...
     Allocation type: HLO temp
     ==========================

  17. Size: 4.00M
     Operator: op_type="reduce_sum" op_name="pmap(mapped_update)/reduce_sum[ axes=(3,)\n                                input_shape=(512, 8, 256, 256) ]"
     Shape: f32[512,8,256]{2,1,0:T(8,128)}
     Unpadded size: 4.00M
     XLA label: %fusion.366 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256]{2,1,0:T(8,128)}) fusion(f32[512,8,256,256]{2,3,1,0:T(8,128)} %get-tuple-element.2649, f32[512,8,256]{2,1,0:T(8,128)} %get-tuple-element.2648, f32[512,8,256]{2,1,0:T(8,128)} %fusion.1865), kind=k...
     Allocation type: HLO temp
     ==========================

  18. Size: 4.00M
     Operator: op_type="reduce_sum" op_name="pmap(mapped_update)/reduce_sum[ axes=(3,)\n                                input_shape=(512, 8, 256, 256) ]"
     Shape: f32[512,8,256]{2,1,0:T(8,128)}
     Unpadded size: 4.00M
     XLA label: %fusion.366 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256]{2,1,0:T(8,128)}) fusion(f32[512,8,256,256]{2,3,1,0:T(8,128)} %get-tuple-element.2649, f32[512,8,256]{2,1,0:T(8,128)} %get-tuple-element.2648, f32[512,8,256]{2,1,0:T(8,128)} %fusion.1865), kind=k...
     Allocation type: HLO temp
     ==========================

  19. Size: 4.00M
     Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
     Shape: f32[512,2048]{1,0:T(8,128)}
     Unpadded size: 4.00M
     XLA label: %reshape.4764 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1789), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
     Allocation type: HLO temp
     ==========================

  20. Size: 4.00M
     Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
     Shape: f32[512,2048]{1,0:T(8,128)}
     Unpadded size: 4.00M
     XLA label: %reshape.4766 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1788), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
     Allocation type: HLO temp
     ==========================
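For reference, a quick size check (a sketch; the 30000 dimension is read off the dump above, presumably the vocabulary) shows the two largest allocations match a single float32 logits-like tensor of shape (512, 100, 30000), i.e. batch x length x vocab:

# 4 bytes per float32 element.
size_gib = 512 * 100 * 30000 * 4 / 2**30
print(round(size_gib, 2))  # 5.72, matching allocation #2 above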

However, I wonder how the authors of this quite popular paper - https://ufal.mff.cuni.cz/pbml/110/art-popel-bojar.pdf - were able to fit bigger batches on a GTX 1080 Ti GPU and reach such high throughput on a single GPU (page 9: for batch size 500 they report 43,400 steps per hour, i.e. 21,700,000 examples per hour), whereas my Reformer reaches only (2048 / 1.15) * 60 * 60 ≈ 6,411,130 examples per hour? Am I mistaken, or what am I doing wrong?
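A quick arithmetic check of that comparison (sketch):

reformer = 2048 / 1.15 * 3600  # examples/hour at batch 2048, 1.15 s/step
popel_bojar = 43_400 * 500     # steps/hour * batch size, from the paper
print(f'{reformer:,.0f} vs {popel_bojar:,}')  # ~6,411,130 vs 21,700,000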

Thanks.

Gin config to train t2t_translate_ende_wmt32k with LSHSelfAttention

Description

Hello, if possible, it would be appreciated if you could upload a gin config to train t2t_translate_ende_wmt32k with LSHSelfAttention. I have tried many times, but it still returns an error.
...

Environment information

OS: <your answer here>

$ pip freeze | grep tensor
# your output here

$ pip freeze | grep jax
# your output here

$ python -V
# your output here

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
...

Pre-trained Reformer models

It would be great if pretrained Reformer models became available (e.g., trained on BooksCorpus and English Wikipedia).

Trax doesn't detect any gpu devices

Description

I am trying to train the ReformerLM model from this tutorial

https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb#scrollTo=vlGjGoGMTt-D

and can't get the Reformer to run on the GPU.

Environment information

OS: ubuntu 18.04

$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.0.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39

$ python -V
python -V

For bugs: reproduction and error logs

Steps to reproduce:

pip install -q -U trax
pip install -q tensorflow

import trax
from trax.models import ReformerLM
import os
import numpy as np
import tensorflow as tf
import jax

def copy_task(batch_size, vocab_size, length):
  """This task is to copy a random string w, so the input is 0w0w."""
  while True:
    assert length % 2 == 0
    w_length = (length // 2) - 1
    w = np.random.randint(low=1, high=vocab_size-1,
                          size=(batch_size, w_length))
    zero = np.zeros([batch_size, 1], np.int32)
    loss_weights = np.concatenate([np.zeros((batch_size, w_length)),
                                   np.ones((batch_size, w_length+2))], axis=1)
    x = np.concatenate([zero, w, zero, w], axis=1)
    yield (x, x, loss_weights)  # Here inputs and targets are the same.

copy_inputs = trax.supervised.Inputs(lambda _: copy_task(16, 32, 10))

data_stream = copy_inputs.train_stream(1)
inputs, targets, mask = next(data_stream)
print("Inputs[0]: %s" % str(inputs[0]))
print("Targets[0]: %s" % str(targets[0]))
print("Mask[0]: %s" % str(mask[0]))

def tiny_transformer_lm(mode):
  return trax.models.TransformerLM(  # You can try trax_models.ReformerLM too.
      d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)

output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl  # Remove old model (Colab shell escape).
trainer = trax.supervised.Trainer(
    model=tiny_transformer_lm,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.
    lr_schedule=trax.lr.MultifactorSchedule,  # Change lr schedule here.
    inputs=copy_inputs,
    output_dir=output_dir,
    has_weights=True)  # Because we have a loss mask, this API may change.

n_epochs = 3
train_steps = 500
eval_steps = 2
for _ in range(n_epochs):
  trainer.train_epoch(train_steps, eval_steps)

Error logs:

/opt/anaconda/envs/trax_3_7/lib/python3.7/site-packages/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')

If I query jax.devices() directly, I get:

[CpuDevice(id=0)]

but TensorFlow has no problem with GPU detection:

tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
PhysicalDevice(name='/physical_device:XLA_GPU:1', device_type='XLA_GPU')]
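A sketch of how to check which backend jax actually initialized: a CPU-only jaxlib wheel silently falls back to CPU, which matches the warning above. The fix would then be installing a CUDA build of jaxlib from the jax-releases bucket matching your CUDA and Python versions (the URL pattern below is the one used in a Dockerfile later in this thread; pick the right wheel for your setup).

from jax.lib import xla_bridge

# 'cpu' here; expect 'gpu' after installing a CUDA-enabled jaxlib, e.g.:
#   pip install --upgrade \
#     https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.43-cp36-none-linux_x86_64.whl
print(xla_bridge.get_backend().platform)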

[Bug] import trax fails with "ImportError: cannot import name 'HistogramProto'"

Description

Fails on import trax

Environment information

Singularity image bootstrapped from "docker ubuntu:latest"

OS: Ubuntu 18.04 LTS

$ pip freeze | grep tensor
mesh-tensorflow==0.1.7
tensor2tensor==1.15.2
tensorboard==2.0.2
tensorflow==2.0.0
tensorflow-datasets==1.3.2
tensorflow-estimator==2.0.1
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.55
jaxlib==0.1.37

$ python -V
Python 3.6.9

Steps to reproduce:

python3 -c 'import trax'

Error logs:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/venv/lib/python3.6/site-packages/trax/__init__.py", line 21, in <module>
    from trax import learning_rate as lr
  File "/venv/lib/python3.6/site-packages/trax/learning_rate.py", line 294, in <module>
    from trax.rl import online_tune
  File "/venv/lib/python3.6/site-packages/trax/rl/__init__.py", line 24, in <module>
    from trax.rl import simulated_env_problem
  File "/venv/lib/python3.6/site-packages/trax/rl/simulated_env_problem.py", line 29, in <module>
    from trax import trainer_lib
  File "/venv/lib/python3.6/site-packages/trax/trainer_lib.py", line 41, in <module>
    from trax import jaxboard
  File "/venv/lib/python3.6/site-packages/trax/jaxboard.py", line 38, in <module>
    from tensorflow import HistogramProto
ImportError: cannot import name 'HistogramProto'
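A sketch of a possible local patch, assuming the only problem is the moved symbol: in TF 2.x the TF1-style protos live under tf.compat.v1, so trax/jaxboard.py could fall back to it. (Upgrading trax/TF to matching versions is the cleaner fix.)

import tensorflow as tf

try:
    from tensorflow import HistogramProto  # TF 1.x top-level location
except ImportError:
    HistogramProto = tf.compat.v1.HistogramProto  # TF 2.x compat location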

Can't find ptxas binary in ${CUDA_DIR}/bin.

Description

I tried to run the Reformer model with the configuration reformer_enwik8.gin and got the error: Can't find ptxas binary in ${CUDA_DIR}/bin.
...

Environment information

OS: Ubuntu 18.04.3 LTS

$ pip freeze | grep tensor
mesh-tensorflow==0.1.7
tensor2tensor==1.15.4
tensorboard==1.15.0
tensorflow-datasets==1.3.2
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.2
tensorflow-probability==0.7.0
tensorrt==6.0.1.4

$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37

$ python -V
python 3.6.8

$ nvcc --version 
cuda10.0 (/usr/local/cuda --> /usr/local/cuda-10.0, but /usr/local/cuda-10.1 exists)

GPU: 2080TI * 4

For bugs: reproduction and error logs

# Steps to reproduce:
Just run trainer.py in trax/trax using the configuration reformer_enwik8.gin.
# Error logs:
[I removed some routine dataset-loading info from the log]
I0119 09:32:55.178084 140128464549696 problem.py:651] Reading data files from /root/tensorflow_datasets/t2t_enwik8_l65k/enwik8_l65k-dev*
INFO:tensorflow:partition: 0 num_data_files: 1
I0119 09:32:55.179685 140128464549696 problem.py:677] partition: 0 num_data_files: 1
I0119 09:32:56.124050 140128464549696 inputs.py:443] Heuristically setting bucketing to False based on shapes of target tensors.
I0119 09:32:56.131589 140128464549696 inputs.py:443] Heuristically setting bucketing to False based on shapes of target tensors.
I0119 09:32:56.136316 140128464549696 inputs.py:443] Heuristically setting bucketing to False based on shapes of target tensors.
I0119 09:33:05.191175 140128464549696 trainer_lib.py:754] Model loaded from ../checkpoints/model.pkl at step 0
Model loaded from ../checkpoints/model.pkl at step 0
I0119 09:33:05.192780 140128464549696 trainer_lib.py:754] Step      0: Starting training using 1 devices
Step      0: Starting training using 1 devices
I0119 09:33:05.194077 140128464549696 trainer_lib.py:754] Step      0: Total number of trainable weights: 215865602
Step      0: Total number of trainable weights: 215865602

2020-01-19 09:33:09.105234: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:09.105464: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:09.105489: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:09.105517: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:09.105532: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:09.105554: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:09.105567: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:09.193084: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:09.193291: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:09.193319: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:09.193338: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:09.193354: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:09.193384: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:09.193418: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:09.345517: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:09.345708: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:09.345732: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:09.345749: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:09.345762: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:09.345776: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:09.345790: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:09.440697: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:09.440881: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:09.440903: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:09.440918: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:09.440930: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:09.440941: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:09.440954: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:09.545554: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:09.545752: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:09.545774: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:09.545791: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:09.545804: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:09.545815: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:09.545827: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:09.730990: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:09.731233: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:09.731260: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:09.731279: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:09.731293: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:09.731305: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:09.731319: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:10.081432: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:10.081621: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:10.081644: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:10.081659: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:10.081671: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:10.081708: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:10.081721: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:13.557328: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:13.557530: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:13.557552: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:13.557567: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:13.557578: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:13.557589: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:13.557601: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:13.633426: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:13.633613: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:13.633636: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:13.633651: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:13.633663: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:13.633700: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:13.633713: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:13.709584: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:13.709778: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:13.709801: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:13.709815: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:13.709826: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:13.709839: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:13.709876: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:14.256316: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:14.256517: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin.  Will back to the GPU driver for PTX -> sass compilation.  This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:14.256540: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:14.256556: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   ./cuda_sdk_lib
2020-01-19 09:33:14.256568: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   /usr/local/cuda
2020-01-19 09:33:14.256579: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77]   .
2020-01-19 09:33:14.256591: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:31.094227: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:31.094430: W external/org_tensorflow/tensorflow/stream_executor/gpu/redzone_allocator.cc:312] Internal: Failed to launch ptxas
Relying on driver to perform ptx compilation. This message will be only logged once.
2020-01-19 09:33:31.177827: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:31.255405: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
Traceback (most recent call last):
  File "/home/xxx/pycharm_proj/trax/trax/trainer.py", line 195, in <module>
    app.run(main)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "/home/xxx/pycharm_proj/trax/trax/trainer.py", line 189, in main
    trainer_lib.train(output_dir=output_dir)
  File "/usr/local/lib/python3.6/dist-packages/gin/config.py", line 1078, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/local/lib/python3.6/dist-packages/gin/utils.py", line 49, in augment_exception_message_and_reraise
    six.raise_from(proxy.with_traceback(exception.__traceback__), None)
  File "<string>", line 3, in raise_from
  File "/usr/local/lib/python3.6/dist-packages/gin/config.py", line 1055, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/xxx/pycharm_proj/trax/trax/supervised/trainer_lib.py", line 641, in train
    trainer.train_epoch(epoch_steps, eval_steps)
  File "/home/xxx/pycharm_proj/trax/trax/supervised/trainer_lib.py", line 305, in train_epoch
    self.train_step(batch)
  File "/home/xxx/pycharm_proj/trax/trax/supervised/trainer_lib.py", line 337, in train_step
    self._step, opt_state, batch, self._model_state, self._rngs)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 149, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 602, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 442, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
  File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 223, in memoized_fun
    ans = call(fun, *args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 499, in _xla_callable
    compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend))
  File "/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py", line 609, in Compile
    return backend.compile(self.computation, compile_options)
  File "/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py", line 161, in compile
    compile_options.device_assignment)
RuntimeError: Internal: Failed to launch ptxas
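A sketch following the hint in the warnings above: point XLA at the CUDA 10.0 install before jax initializes. Note also that the repeated "Cannot allocate memory" fork failures suggest host memory pressure may be what actually prevents ptxas from launching, so freeing host RAM is worth trying as well.

import os

# Must be set before the first jax/XLA call in the process.
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda-10.0'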

Segmentation fault error when using tensorflow dataset

Description

Hi all, I'm building a dataset using TensorFlow and trax in an Ubuntu Docker container, but I encountered a segmentation fault.
When I run the code without trax, there is no error. Please help me.

Environment information (Dockerfile)

FROM tensorflow/tensorflow:latest-gpu-py3

RUN apt-get -y update
RUN apt-get -y upgrade
RUN apt-get install -y less wget git
# workaround for the matplotlib + trax error
RUN apt-get install -y python3-cairocffi python3-gi gir1.2-gtk-3.0

RUN pip install -U pip
RUN pip install -U six
RUN pip install -U matplotlib==3.1.3
RUN pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.43-cp36-none-linux_x86_64.whl
RUN pip install --upgrade jax
WORKDIR /tmp/docker_works
RUN git clone https://github.com/google/trax.git
WORKDIR /tmp/docker_works/trax
RUN sed -i '1s/^/import tensorflow\n/' ./trax/models/research/bert.py
RUN sed -i -e "s/from tensorflow.train import load_checkpoint//g" ./trax/models/research/bert.py
RUN sed -i -e "s/load_checkpoint/tensorflow.train.load_checkpoint/g" ./trax/models/research/bert.py
RUN python setup.py install
WORKDIR /tmp/docker_works

For bugs: reproduction and error logs

code

import matplotlib as mlp
mlp.use('Agg')
import trax
import faulthandler
faulthandler.enable()
import pickle
import random
import numpy as np
import tensorflow as tf

if __name__ == "__main__":
	with tf.io.TFRecordWriter('./data/tmp.tfrecord') as writer:
		for i in range(10):
			example = tf.train.Example(features=tf.train.Features(
				feature = {'input_ids':tf.train.Feature(int64_list=tf.train.Int64List(value=range(10))),
						   'labels':tf.train.Feature(int64_list=tf.train.Int64List(value=range(10)))
						  }
			))
			writer.write(example.SerializeToString())
	dataset = tf.data.TFRecordDataset('./data/tmp.tfrecord')
	print(dataset)

Error logs:

log with trax

$ python script/sample_with_trax.py
Fatal Python error: Segmentation fault

Current thread 0x00007eff32052740 (most recent call first):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/context.py", line 1081 in _initialize_physical_devices
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/context.py", line 815 in config
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/context.py", line 496 in ensure_initialized
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 95 in convert_to_eager_tensor
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 266 in _constant_impl
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 258 in constant
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 317 in _constant_tensor_conversion_function
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1302 in convert_to_tensor
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/data/ops/readers.py", line 55 in _create_or_validate_filenames_dataset
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/data/ops/readers.py", line 316 in __init__
  File "script/sample_with_trax.py", line 20 in <module>
Segmentation fault (core dumped)

log without trax (comment out import trax)

$ python script/sample_with_trax.py
<TFRecordDatasetV2 shapes: (), types: tf.string>

Maintained Documentation

Description

Trax is a library for deep learning that focuses on sequence models and reinforcement learning. It combines performance with code clarity and maintained documentation and tests.
...

Sorry to bother; I'll be brief. I don't think the "maintained documentation" part of the statement is true (yet?). I like the work, and I respect every project that goes deep into neural-network implementation, but I feel there is a critical lack of documentation for this project.

I took a look at Flax's Read the Docs and, although the projects have different motives, I believe there should be something similar for Trax.

Again, sorry to bother. I wish the project all the luck and success.

ReformerLM document representation

Description

I have managed to adapt the colab code for learning document representations, and the training and generation phases work smoothly. I adapted the sample() method to return the final state after processing the document. However, this final state seems to be a complex list containing a variety of information. What I want is the hidden representation of the topmost layer, which I assume represents the whole document. Is there any way to obtain that hidden representation? I am providing the relevant part of my code. Any suggestions will be appreciated.
...

Environment information

OS: Ubuntu 16.04 (Irrelevant)

$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.14.0
tensorboard==1.15.0
tensorflow-datasets==1.3.2
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.0
tensorflow-probability==0.7.0


$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37


$ python -V
Python 3.6.8

For bugs: reproduction and error logs

# Steps to reproduce:
The relevant part of the code I wrote:

# np is trax's jax-backed numpy and onp is plain numpy, as set up in the
# Reformer colab this code was adapted from.

# Prepare a jitted copy of the model.
jit_model_infer = trax.layers.base._accelerate(
    model_infer._forward_internal, trax.math.device_count())
# Set up the initial state for sampling.
infer_state = model_infer.new_weights_and_state(
    trax.supervised.trainer_lib.ShapeDtype((1, 1), dtype=np.int32))[1]
infer_state = trainer._for_n_devices(infer_state)

def docvector(length=0, prompt=None):
  """Run the ReformerLM over a document and return the final state."""
  model_weights = trainer._opt_state[0][0]
  length = len(prompt.split(" "))
  # Token id 0 is the equivalent of a "start" token.
  cur_inputs = np.zeros((trax.math.device_count(), 1, 1), dtype=np.int32)

  cur_state = infer_state
  rngs = trax.math.random.split(trax.math.random.get_prng(0),
                                trax.math.device_count())

  # prompt is the input document as a string.
  prompt = np.asarray(
      [TOKENIZER.EncodeAsIds(prompt)] * trax.math.device_count())
  logits, cur_state = jit_model_infer(
      cur_inputs, model_weights, cur_state, rngs)
  for iteration in range(length):
    cur_samples = onp.array(prompt[:, iteration], dtype=int)
    cur_inputs = np.array(cur_samples[:, None, None])
    logits, cur_state = jit_model_infer(
        cur_inputs, model_weights, cur_state, rngs)

  # This is a list of lists/dictionaries/tensors/tuples.
  # How do I get the final hidden state out of it?
  return cur_state
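One generic way to inspect such a nested state (a sketch using jax's tree utilities, with no assumptions about the exact layout) is to flatten it and look at the leaf shapes; the topmost layer's activations should then be identifiable by shape:

import jax

leaves, _ = jax.tree_util.tree_flatten(cur_state)
for leaf in leaves:
    print(getattr(leaf, 'shape', type(leaf)))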


# Error logs:
N/A

Undefined names in lax_numpy_test.py

Description

flake8 testing of https://github.com/google/trax on Python 3.8.0

$ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics

./trax/tf_numpy/jax_tests/lax_numpy_test.py:588:12: F821 undefined name 'FLAGS'
    if not FLAGS.jax_enable_x64 and any(
           ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:740:17: F821 undefined name 'dtypes'
    tol_spec = {dtypes.bfloat16: 3e-1, onp.float16: 0.15}
                ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1574:13: F821 undefined name 'dtypes'
    dtype = dtypes.canonicalize_dtype(dtype)
            ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1618:13: F821 undefined name 'api'
    csame = api.jit(same)
            ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1651:11: F821 undefined name 'api'
    fun = api.jit(fun)
          ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1663:11: F821 undefined name 'api'
    fun = api.jit(fun)
          ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1671:6: F821 undefined name 'api'
    @api.jit
     ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1691:42: F821 undefined name 'api'
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
                                         ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1692:42: F821 undefined name 'api'
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))
                                         ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1697:6: F821 undefined name 'api'
    @api.jit
     ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1705:6: F821 undefined name 'api'
    @api.jit
     ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1723:12: F821 undefined name 'api'
    cfoo = api.jit(foo)
           ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1820:9: F821 undefined name 'lax'
    x = lax.add(lnp.eye(3, dtype=lnp.float_), 0.)
        ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2006:23: F821 undefined name 'dtypes'
    dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type
                      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2191:14: F821 undefined name 'api'
    result = api.grad(test_fail)(x)
             ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2220:39: F821 undefined name 'api'
    self.assertAllClose(onp.int64(7), api.jit(lambda x: x)(onp.longlong(7)),
                                      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2249:50: F821 undefined name 'lax'
    self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77)))
                                                 ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2255:26: F821 undefined name 'lax'
                    type(lax.iota(onp.int32, 77)))
                         ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2274:9: F821 undefined name 'api'
    f = api.grad(lambda x: lnp.sum(lnp.tanh(x)))
        ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2284:11: F821 undefined name 'jax'
      y = jax.ops.index_add(onp.ones(10,), [2, 4, 5], u)
          ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2287:14: F821 undefined name 'lax'
      return lax.tie_in(y, 7.)
             ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2289:40: F821 undefined name 'api'
    self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)),
                                       ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2295:9: F821 undefined name 'api'
    f = api.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x))))
        ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2311:23: F821 undefined name 'dtypes'
    dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type
                      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2333:14: F821 undefined name 'api'
    @partial(api.jit, static_argnums=(1,))
             ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2520:13: F821 undefined name 'FLAGS'
        not FLAGS.jax_enable_x64):
            ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2604:19: F821 undefined name 'FLAGS'
      prev_flag = FLAGS.jax_numpy_rank_promotion
                  ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2605:7: F821 undefined name 'FLAGS'
      FLAGS.jax_numpy_rank_promotion = "allow"
      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2608:7: F821 undefined name 'FLAGS'
      FLAGS.jax_numpy_rank_promotion = prev_flag
      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2611:19: F821 undefined name 'FLAGS'
      prev_flag = FLAGS.jax_numpy_rank_promotion
                  ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2612:7: F821 undefined name 'FLAGS'
      FLAGS.jax_numpy_rank_promotion = "raise"
      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2615:7: F821 undefined name 'FLAGS'
      FLAGS.jax_numpy_rank_promotion = prev_flag
      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2618:19: F821 undefined name 'FLAGS'
      prev_flag = FLAGS.jax_numpy_rank_promotion
                  ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2619:7: F821 undefined name 'FLAGS'
      FLAGS.jax_numpy_rank_promotion = "warn"
      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2633:7: F821 undefined name 'FLAGS'
      FLAGS.jax_numpy_rank_promotion = prev_flag
      ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2638:6: F821 undefined name 'api'
    @api.jit
     ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2643:6: F821 undefined name 'api'
    @api.jit
     ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2656:15: F821 undefined name 'jax'
      y = y + jax.grad(lambda z: lnp.sum(lnp.maximum(z, 0.)))(x)
              ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2659:19: F821 undefined name 'lax'
    f = lambda y: lax.fori_loop(0, 5, body, (y, y))
                  ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2660:15: F821 undefined name 'linear_util'
    wrapped = linear_util.wrap_init(f)
              ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2661:10: F821 undefined name 'partial_eval'
    pv = partial_eval.PartialVal(
         ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2662:8: F821 undefined name 'jax'
      (jax.ShapedArray((3, 4), onp.float32), jax.core.unit))
       ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2662:46: F821 undefined name 'jax'
      (jax.ShapedArray((3, 4), onp.float32), jax.core.unit))
                                             ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2663:20: F821 undefined name 'partial_eval'
    _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv])
                   ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2709:15: F821 undefined name 'lax'
    HIGHEST = lax.Precision.HIGHEST
              ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2786:20: F821 undefined name 'dtypes'
  return lnp.finfo(dtypes.canonicalize_dtype(dtype)).bits
                   ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2803:5: F821 undefined name 'check_grads'
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)
    ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2813:5: F821 undefined name 'check_grads'
    check_grads(op, (special_value,), order, ["fwd", "rev"],
    ^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2825:5: F821 undefined name 'check_grads'
    check_grads(f, (1.,), order=1)
    ^
49    F821 undefined name 'FLAGS'
49

https://flake8.pycqa.org/en/latest/user/error-codes.html

On the flake8 test selection: this PR does not focus on "style violations" (the majority of flake8 error codes, which psf/black can autocorrect). Instead, these tests focus on runtime safety and correctness:

  • E9 tests are about Python syntax errors, usually raised because flake8 cannot build an Abstract Syntax Tree (AST). Often these issues are a sign of unused code or code that has not been ported to Python 3. These would be compile-time errors in a compiled language, but in a dynamic language like Python they result in the script halting/crashing on the user.
  • F63 tests are usually about confusion between identity and equality in Python. Using ==/!= to compare str, bytes, and int literals is the classic case. These are areas where a == b is True but a is b is False (or vice versa). Python >= 3.8 will raise SyntaxWarnings on these instances.
  • F7 tests catch logic errors and syntax errors in type hints.
  • F82 tests are almost always undefined names, which are usually a sign of a typo, missing imports, or code that has not been ported to Python 3. These also would be compile-time errors in a compiled language, but in Python a NameError is raised, which will halt/crash the script on the user. A sketch of the missing imports for this file follows below.
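Most of the F821 hits above would be resolved by restoring the imports the test file evidently lost when it was ported; a sketch (module paths assumed from the jax codebase these tests came from):

# Candidate imports for trax/tf_numpy/jax_tests/lax_numpy_test.py.
from absl import flags
import jax
from jax import api, dtypes, lax, linear_util
from jax.interpreters import partial_eval
from jax.test_util import check_grads

FLAGS = flags.FLAGS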

Environment information

OS: <your answer here>

$ pip freeze | grep tensor
# your output here

$ pip freeze | grep jax
# your output here

$ python -V
# your output here

For bugs: reproduction and error logs

# Steps to reproduce:
...

flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics

# Error logs:
...

Segmentation fault on import trax

Description

I'm getting a segmentation fault when trying to import trax.
Has anyone encountered the same problem?

Environment information

OS: Ubuntu 18.04.3 LTS
Docker image: tensorflow/tensorflow:2.1.0-gpu-py3

$ pip freeze | grep tensor
mesh-tensorflow==0.1.12
tensor2tensor==1.15.4
tensorboard==2.1.1
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-gpu==2.1.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.40

$ pip freeze | grep trax
trax==1.2.3

$ pip freeze | grep matplotlib
matplotlib==3.2.0
# also tried matplotlib==2.2.5 with same results

$ python -V
Python 3.6.9

$ lshw -C display
  *-display                 
       description: VGA compatible controller
       product: GP104 [GeForce GTX 1080]
       vendor: NVIDIA Corporation
       physical id: 0
       bus info: pci@0000:01:00.0
       version: a1
       width: 64 bits
       clock: 33MHz
       capabilities: vga_controller bus_master cap_list rom
       configuration: driver=nvidia latency=0
       resources: irq:126 memory:ee000000-eeffffff memory:d0000000-dfffffff memory:e0000000-e1ffffff ioport:e000(size=128) memory:ef000000-ef07ffff

For bugs: reproduction and error logs

# Steps to reproduce:
Python 3.6.9 (default, Nov  7 2019, 10:44:02) 
[GCC 8.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> tf.__version__
'2.1.0'
>>> print(tf.config.list_physical_devices('GPU'))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
>>> import trax
Segmentation fault (core dumped)
Python 3.6.9 (default, Nov  7 2019, 10:44:02)                     
[GCC 8.3.0] on linux                                                       
Type "help", "copyright", "credits" or "license" for more information.
>>> import faulthandler                                                      
>>> faulthandler.enable()                                            
>>> import trax                                                                             
Fatal Python error: Segmentation fault                                       
                                                                        
Thread 0x00007f77d6bb5740 (most recent call first):               
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 1007 in addfont
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 991 in __init__
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 1334 in _rebuild
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 1343 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed                 
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module     
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked      
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load         
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/contour.py", line 16 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed                      
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module     
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked      
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load         
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/colorbar.py", line 31 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked                    
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked  
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load      
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/pyplot.py", line 32 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked                       
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked  
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load      
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist      
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/__init__.py", line 1258 in use
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py", line 358 in wrapper
  File "/tmp/env/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py", line 296 in wrapper
  File "/tmp/env/lib/python3.6/site-packages/tensor2tensor/data_generators/video_generated.py", line 35 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "<frozen importlib._bootstrap>", line 994 in _gcd_import
  File "/usr/lib/python3.6/importlib/__init__.py", line 126 in import_module
  File "/tmp/env/lib/python3.6/site-packages/tensor2tensor/data_generators/all_problems.py", line 140 in import_modules
  File "/tmp/env/lib/python3.6/site-packages/tensor2tensor/problems_colab.py", line 36 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
  File "/tmp/env/lib/python3.6/site-packages/trax/supervised/inputs.py", line 31 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
  File "/tmp/env/lib/python3.6/site-packages/trax/supervised/__init__.py", line 18 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "/tmp/env/lib/python3.6/site-packages/trax/rl/simulated_env_problem.py", line 35 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
  File "/tmp/env/lib/python3.6/site-packages/trax/rl/__init__.py", line 24 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "/tmp/env/lib/python3.6/site-packages/trax/learning_rate.py", line 294 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
  File "/tmp/env/lib/python3.6/site-packages/trax/__init__.py", line 19 in <module>
  File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
  File "<frozen importlib._bootstrap_external>", line 678 in exec_module
  File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
  File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 971 in _find_and_load
  File "<stdin>", line 1 in <module>
Segmentation fault (core dumped)
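
A workaround sketch (an assumption on my part, not confirmed in this report): the trace above dies while matplotlib rebuilds its font cache, so importing matplotlib up front with a non-GUI backend (or clearing ~/.cache/matplotlib first) may avoid, or at least isolate, the crash before trax pulls matplotlib in via tensor2tensor.

import os
os.environ.setdefault('MPLBACKEND', 'Agg')  # force a non-GUI backend (assumption)
import matplotlib.font_manager  # trigger the font cache build before trax does
import trax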

Number of inputs Serial.forward error

Description

I ran into an input-layer issue while testing ResNet50 on CIFAR-10 as a toy example (see the sketch after the error log below):
ValueError: number of inputs (2) to Serial.forward less than n_in (3).

https://gist.github.com/rodrigobaron/e0874af5e8e32b18411fa4bb30e49174
...

Environment information

Jax==0.1.52
Trax==1.2.2
Tensorflow==1.15.0

OS: Linux (Google Colab)

For bugs: reproduction and error logs

# Steps to reproduce:
Import https://gist.github.com/rodrigobaron/e0874af5e8e32b18411fa4bb30e49174 on Google Colab and run with GPU runtime.
# Error logs:
LayerError: Exception passing through layer Serial (in _forward_internal):
  layer created in file [...]/trax/supervised/trainer_lib.py, line 674
  layer input shapes: (ShapeDtype{shape:(32, 32, 32, 3), dtype:float32}, ShapeDtype{shape:(32, 10), dtype:float32})

  File [...]/trax/layers/combinators.py, line 59, in forward_with_state
    self._validate_forward_inputs(xs)

  File [...]/trax/layers/combinators.py, line 137, in _validate_forward_inputs
    ' ({})'.format(len(xs), self.n_in))

ValueError: number of inputs (2) to Serial.forward less than n_in (3)
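
A hedged guess at the cause (not verified against the gist): the combined model+loss Serial expects three inputs (inputs, targets, weights), while the stream supplies only two, as the input shapes in the log show. Under that assumption, one sketch of a fix is to yield an explicit per-example weights array as the third element:

import numpy as onp

def train_stream():
  while True:
    x = onp.zeros((32, 32, 32, 3), dtype=onp.float32)  # hypothetical CIFAR-10 batch
    y = onp.zeros((32, 10), dtype=onp.float32)         # hypothetical one-hot targets
    w = onp.ones((32,), dtype=onp.float32)             # per-example weights: the third input
    yield (x, y, w)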

EagerTensor is not a valid JAX type

Description

Hello, I'm facing a problem while trying to work with the Trax Trainer class. I loaded my dataset from a TFRecords file and created a Dataset instance using the tf.data Dataset API. Then I tried to feed my dataset to the Trax trainer, but got this error. Could you please tell me how to accomplish this? I haven't found anything explaining how to use the Dataset API with the Trax library. Thanks! (A conversion sketch follows after the error log below.)

Environment information

OS: Google Colab notebook

For bugs: reproduction and error logs

Steps to reproduce:

Pass a dataset iterator to the trax.supervised.inputs.Inputs class

Error logs:

TypeError: Argument '[[ 2 16 9 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
...
[ 2 70 21 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
[ 2 47 14 ... 0 0 0]]' of type <class 'tensorflow.python.framework.ops.EagerTensor'> is not a valid JAX type
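
A sketch of one way to bridge tf.data and Trax (my assumption, not an official recipe): convert each batch to numpy before it reaches the trainer, since JAX does not accept EagerTensors. The file name and the stream signature here are hypothetical.

import tensorflow as tf
import trax

dataset = tf.data.TFRecordDataset('data.tfrecord')  # hypothetical; parse and batch as needed

def train_stream(n_devices):
  del n_devices  # device layout not handled in this sketch
  for batch in dataset.as_numpy_iterator():  # EagerTensor -> numpy (TF >= 2.1)
    yield batch

inputs = trax.supervised.inputs.Inputs(train_stream)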

Codelabs!

As per our (currently my) discussion on the Trax gitter, there is significant interest in having Trax codelabs. Here's a live prototype, and a screenshot to get the feel:

[Screenshot: live codelab prototype, 2020-02-29]

As I mentioned on the Trax gitter, while this might mirror the content of a notebook someone could just use on Colab, it could provide some added benefit by helping to reduce cognitive load. And as I also mentioned there, we can have colabs as well as codelabs compiled from those colabs 😉

E.g. a notebook and a markdown file that is used to generate the codelab in the linked prototype.

Hosting this on our docs site just for demo purposes. The prototype was a fork of the open-wc codelab package, which iiuc was generated using https://github.com/googlecodelabs/tools.

Related but separate is the possibility of a Trax docs site? 😏 We did ours with VuePress (again forked from open-wc), and it is working out really well.

Possible padding bug in Resnet50

Description

I think in resnet.py the padding option for MaxPool should be 'SAME'. The shape of the output of MaxPool, and of the Resnet50ConvBlock right after it, becomes B x 55 x 55 x C instead of B x 56 x 56 x C. See Keras and PyTorch for comparison.
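
A quick shape check (my own arithmetic, mirroring the claim above): for a 112x112 feature map going into a 3x3, stride-2 max-pool, 'VALID' padding yields 55 while 'SAME' yields 56.

import math

def pool_out(size, window=3, stride=2, padding='VALID'):
  if padding == 'VALID':
    return math.floor((size - window) / stride) + 1
  return math.ceil(size / stride)  # 'SAME'

print(pool_out(112, padding='VALID'))  # 55 -> B x 55 x 55 x C
print(pool_out(112, padding='SAME'))   # 56 -> B x 56 x 56 x C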

...

Environment information

OS: <your answer here>

$ pip freeze | grep tensor
# your output here

$ pip freeze | grep jax
# your output here

$ python -V
# your output here

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
...

Is there any plan to merge with Flax?

Apologies in advance for the question that might seem off topic.

Given the early breakout of JAX, it seems there is no convergence yet on a high-level library for deep and reinforcement learning.

Do you personally have any plan to merge with Flax?
The question comes from the desire to contribute efficiently to these libraries, where efficiency means a low probability of the work being superseded.

Support layer custom names

Description

It is extremely difficult to debug a nested Serial layer stack, especially when using layers like Branch and SerialWithSideOutputs (layers that are built from other basic combinators), because the error stack shows them all as just Serial layers.

I've made some small changes to base.Layer and combinators.py so that base.Layer supports overriding the layer name (self.__class__.__name__) with a user-supplied name, and it has helped a lot in debugging large models.

# base.py
class Layer(object):
  def __init__(self, n_in=1, n_out=1, name=None):  # Added name
    self._name = name or self.__class__.__name__
    ...
# Replace self.__class__.__name__ in LayerError calls with self._name

# combinators.py
def Branch(*layers, name='Branch'):
  return Serial(..., name=name)

If this seems OK, I'd be glad to make a PR for this.
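
Hypothetical usage under this proposal (the name argument is the suggested addition, not an existing API): the custom name would then appear in LayerError stacks instead of a bare Serial.

from trax import layers as tl

# With the proposed name argument, this Branch would report as 'MyBranch'
# in error stacks rather than as an anonymous Serial.
blk = tl.Branch(tl.Dense(32), tl.Dense(32), name='MyBranch')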

Reformer example on Colab: other book fails.

Description

I generated BPE vocab (size=320) and model files and compared them with the Crime & Punishment files; everything went OK
Number of tokens: 750515 and (device count, tokens per device) = (8, 1048576)
until training:
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 18.74G of 8.00G hbm. Exceeded hbm capacity by 10.74G.
Wow, a 50% increase in tokens jumps memory from 512M to 18G?
Or did I miss something else?

Environment information

https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb

trax.math.device_count() returning 1 whereas I have 4 GPU Tesla V100

Description

...

Environment information

OS: Ubuntu 18.04 / NVIDIA DGX Station (Desktop)

$ pip freeze | grep tensor

bert-tensorflow==1.0.1
mesh-tensorflow==0.1.9
tensor2tensor==1.15.4
tensorboard==2.0.2
tensorboardX==1.9
tensorflow-datasets==1.3.2
tensorflow-estimator==2.0.1
tensorflow-gan==2.0.0
tensorflow-gpu==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax

jax==0.1.57
jaxlib==0.1.37

$ python -V
Python 3.7.5

For bugs: reproduction and error logs

# Steps to reproduce:
Ran the text-generation notebook code on my own machine:
https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb#scrollTo=PdAwmpS220ub

Set up the data pipeline.

import numpy as onp  # as in the notebook; PAD_AMOUNT and IDS are also defined there

def my_inputs(n_devices):
  while True:
    inputs = []
    mask = []
    pad_amounts = onp.random.choice(PAD_AMOUNT, n_devices)
    for i in range(n_devices):
      inputs.append(onp.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                            mode='constant'))
      mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                          (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                          mode='constant'))
    inputs = onp.stack(inputs)
    mask = onp.stack(mask)
    yield (inputs, inputs, mask)

print("(device count, tokens per device) = ",
      next(my_inputs(trax.math.device_count()))[0].shape)

Error logs:

...

(device count, tokens per device) =  (1, 524288)
/home/sn/anaconda3/envs/py37/lib/python3.7/site-packages/jax/lib/xla_bridge.py:119: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
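
A quick check that may explain this (my assumption, consistent with the warning above): if jaxlib was installed without CUDA support, JAX only sees the CPU, so trax.math.device_count() reports 1 regardless of how many GPUs lshw lists. Installing a CUDA-enabled jaxlib should make the GPUs visible.

import jax

print(jax.devices())       # expect GPU devices here once a CUDA-enabled jaxlib is installed
print(jax.device_count())  # what trax.math.device_count() reflects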

How to export a TensorFlow model

Description

Can a model.pkl file be converted to a TensorFlow model?
...

Environment information

OS: <your answer here>

$ pip freeze | grep tensor
# your output here

$ pip freeze | grep jax
# your output here

$ python -V
# your output here

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
...

[Bug] loss_fn argument for Trainer must not be a function since 1.2.4

Description

As shown in previous examples, loss_fn should be a callable passed like this:

trainer = trax.supervised.Trainer(
    model=eval(train_model.selector),
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(train_stream),
    output_dir=output_dir,
)

However, since the latest upgrade to 1.2.4 this no longer works.

In trainer_lib, the loss_fn gets passed to a Serial constructor:

m = tl.Serial(model(mode='train'), loss_fn)

Which in turn runs _ensure_flat in its constructor:

sublayers = _ensure_flat(sublayers)

However, all objects in layers have to be of type base.Layer:

def _ensure_flat(layers):
  """Ensures that layers is a single flat list of Layer instances."""
  if len(layers) == 1 and layers[0] is None:
    layers = ()
  else:
    layers = _deep_flatten(layers)
  for obj in layers:
    if not isinstance(obj, base.Layer):
      raise ValueError(
          f'Found nonlayer object ({obj}) in layers: {layers}')
  return layers

Thus we'll see an exception:

ValueError: Found nonlayer object (<function CrossEntropyLoss at 0x7fc5be59a9e0>) in layers:
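
A workaround matching my reading of the error above (an assumption, not an official fix): since 1.2.4 the loss must already be a Layer instance by the time it reaches Serial, so call the loss-producing function instead of passing the function itself.

trainer = trax.supervised.Trainer(
    model=eval(train_model.selector),        # as in the snippet above
    loss_fn=trax.layers.CrossEntropyLoss(),  # note the parentheses: a Layer, not a function
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(train_stream),
    output_dir=output_dir,
)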

Running the Reformer with multiple batches

Very interesting and useful library, thanks! My question: how does one arrange the my_inputs function in order to run multiple batches with the Reformer? The text-generation Colab only covers a single batch. I got some useful information from the configs (batch_fn), but the arrangement of the input is still not clear. I have a sequence of 4M tokens and a vocab size of 50,000 for a language-model problem.
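
A sketch of one possible arrangement (my assumption, patterned on the my_inputs example elsewhere in this document): keep the batch on the leading axis, so each yield is (inputs, targets, mask) with shape (batch, seq_len). SEQ_LEN and BATCH here are hypothetical.

import numpy as onp

SEQ_LEN = 65536   # hypothetical chunk length per example
BATCH = 8         # hypothetical examples per device
VOCAB = 50000

def my_inputs(n_devices):
  while True:
    x = onp.random.randint(0, VOCAB, size=(n_devices * BATCH, SEQ_LEN))  # stand-in token ids
    mask = onp.ones_like(x, dtype=onp.float32)
    yield (x, x, mask)  # (inputs, targets, mask) for a language model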

Proper comparison on real-world applications

I'd like to see a proper comparison against the Transformer (GPT-2) on text generation with the same number of parameters: how it compares when trained on sequences of the same length, and when the Reformer uses a bigger context window.

Thanks a lot for your unique contribution, but substantial empirical and qualitative evidence is still lacking.

Failed to sample from the Reformer model

Description

Sampling from the Reformer model fails after training on my local machine.
No code was changed.

Environment information

OS: Ubuntu 16.04 LTS

$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.14.1
tensorboard==1.15.0
tensorboardcolab==0.0.22
tensorflow-datasets==2.0.0
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-privacy==0.2.2
tensorflow-probability==0.7.0
tensorflow-serving-api-gpu==1.13.0
tensorflow-tensorboard==0.4.0

$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37

$ python -V
Python 3.6.9 :: Anaconda, Inc.

For bugs: reproduction and error logs

# Steps to reproduce:
Ran the text-generation notebook on the local machine:
https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb#scrollTo=favRDt3U4CJY

# Error logs:
...
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/base.py in _forward_internal(self, x, weights, state, rng)
    452         outputs, s = self.forward_with_state(
--> 453             x, weights=weights, state=state, rng=rng)
    454       else:

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/combinators.py in forward_with_state(self, xs, weights, state, **kwargs)
     59     self._validate_forward_inputs(xs)
---> 60     rngs = _pop_rng_and_split(kwargs, self._n_layers)
     61     if not self.sublayers:  # No-op: leave args unchanged.

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/combinators.py in _pop_rng_and_split(args_dict, n_copies)
    688     return (None,) * n_copies
--> 689   return math.random.split(rng, n_copies)
    690 

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/math/backend.py in split(self, prng, num)
    122   def split(self, prng, num=2):
--> 123     return backend()['random_split'](prng, num)
    124 

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/random.py in split(key, num)
    243   """
--> 244   return _split(key, num)
    245 

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    148     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    150     return tree_unflatten(out_tree(), out)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    604     tracers = map(top_trace.full_raise, args)
--> 605     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    606   return apply_todos(env_trace_todo(), outs)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    125     fun, aux = partial_eval(f, self, in_pvs)
--> 126     out_flat = call_primitive.bind(fun, *in_consts, **params)
    127     out_pvs, jaxpr, env = aux()

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    601     with new_sublevel():
--> 602       outs = primitive.impl(f, *args, **params)
    603   else:

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    441   backend = params['backend']
--> 442   compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
    443   try:

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    222     else:
--> 223       ans = call(fun, *args)
    224       cache[key] = (ans, fun.stores)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
    458   with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    460     assert not env  # no subtraces here

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    151 
--> 152     ans = self.f(*args, **dict(self.params, **kwargs))
    153     del args

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/random.py in _split(key, num)
    248   counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
--> 249   return lax.reshape(threefry_2x32(key, counts), (num, 2))
    250 

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    148     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    150     return tree_unflatten(out_tree(), out)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    604     tracers = map(top_trace.full_raise, args)
--> 605     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    606   return apply_todos(env_trace_todo(), outs)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    125     fun, aux = partial_eval(f, self, in_pvs)
--> 126     out_flat = call_primitive.bind(fun, *in_consts, **params)
    127     out_pvs, jaxpr, env = aux()

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    604     tracers = map(top_trace.full_raise, args)
--> 605     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    606   return apply_todos(env_trace_todo(), outs)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    125     fun, aux = partial_eval(f, self, in_pvs)
--> 126     out_flat = call_primitive.bind(fun, *in_consts, **params)
    127     out_pvs, jaxpr, env = aux()

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    601     with new_sublevel():
--> 602       outs = primitive.impl(f, *args, **params)
    603   else:

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    441   backend = params['backend']
--> 442   compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
    443   try:

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    222     else:
--> 223       ans = call(fun, *args)
    224       cache[key] = (ans, fun.stores)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
    458   with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    460     assert not env  # no subtraces here

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    151 
--> 152     ans = self.f(*args, **dict(self.params, **kwargs))
    153     del args

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/random.py in threefry_2x32(keypair, count)
    215   """
--> 216   key1, key2 = keypair
    217   if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:

ValueError: not enough values to unpack (expected 2, got 1)

During handling of the above exception, another exception occurred:

LayerError                                Traceback (most recent call last)
<ipython-input-30-58abfd2a9337> in <module>
      1 # Sample from the Reformer language model, given a prefix.
----> 2 samples = sample(length=128, prompt="There was a time when")
      3 for ids in samples:
      4   print(TOKENIZER.DecodeIds(ids.tolist()))

<ipython-input-29-f9fee3fa3424> in sample(length, prompt)
     19         model_weights,
     20         cur_state,
---> 21         rngs)
     22 
     23     if prompt is not None and iteration < prompt.shape[1]:

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    147     _check_args(args_flat)
    148     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    150     return tree_unflatten(out_tree(), out)
    151 

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    600   if top_trace is None:
    601     with new_sublevel():
--> 602       outs = primitive.impl(f, *args, **params)
    603   else:
    604     tracers = map(top_trace.full_raise, args)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    440   device = params['device']
    441   backend = params['backend']
--> 442   compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
    443   try:
    444     return compiled_fun(*args)

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    221       fun.populate_stores(stores)
    222     else:
--> 223       ans = call(fun, *args)
    224       cache[key] = (ans, fun.stores)
    225     return ans

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
    457   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    458   with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    460     assert not env  # no subtraces here
    461     del master, env

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     gen = None
    151 
--> 152     ans = self.f(*args, **dict(self.params, **kwargs))
    153     del args
    154     while stack:

~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/base.py in _forward_internal(self, x, weights, state, rng)
    460       name, trace = self.__class__.__name__, _short_traceback()
    461       raise LayerError(name, '_forward_internal',
--> 462                        self._caller, signature(x), trace)
    463 
    464   def _forward_abstract(self, input_signature):

LayerError: Exception passing through layer Serial (in _forward_internal):
  layer created in file [...]/models/reformer/reformer.py, line 612
  layer input shapes: ShapeDtype{shape:(1, 1, 1), dtype:int32}

  File [...]/trax/layers/combinators.py, line 60, in forward_with_state
    rngs = _pop_rng_and_split(kwargs, self._n_layers)

  File [...]/trax/layers/combinators.py, line 689, in _pop_rng_and_split
    return math.random.split(rng, n_copies)

  File [...]/trax/math/backend.py, line 123, in split
    return backend()['random_split'](prng, num)

  File [...]/site-packages/jax/random.py, line 244, in split
    return _split(key, num)

  File [...]/site-packages/jax/api.py, line 149, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)

  File [...]/site-packages/jax/core.py, line 605, in call_bind
    outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))

  File [...]/jax/interpreters/partial_eval.py, line 126, in process_call
    out_flat = call_primitive.bind(fun, *in_consts, **params)

  File [...]/site-packages/jax/core.py, line 602, in call_bind
    outs = primitive.impl(f, *args, **params)

  File [...]/jax/interpreters/xla.py, line 442, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))

  File [...]/site-packages/jax/linear_util.py, line 223, in memoized_fun
    ans = call(fun, *args)

  File [...]/jax/interpreters/xla.py, line 459, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 152, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/random.py, line 249, in _split
    return lax.reshape(threefry_2x32(key, counts), (num, 2))

  File [...]/site-packages/jax/api.py, line 149, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)

  File [...]/site-packages/jax/core.py, line 605, in call_bind
    outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))

  File [...]/jax/interpreters/partial_eval.py, line 126, in process_call
    out_flat = call_primitive.bind(fun, *in_consts, **params)

  File [...]/site-packages/jax/core.py, line 605, in call_bind
    outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))

  File [...]/jax/interpreters/partial_eval.py, line 126, in process_call
    out_flat = call_primitive.bind(fun, *in_consts, **params)

  File [...]/site-packages/jax/core.py, line 602, in call_bind
    outs = primitive.impl(f, *args, **params)

  File [...]/jax/interpreters/xla.py, line 442, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))

  File [...]/site-packages/jax/linear_util.py, line 223, in memoized_fun
    ans = call(fun, *args)

  File [...]/jax/interpreters/xla.py, line 459, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 152, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/random.py, line 216, in threefry_2x32
    key1, key2 = keypair

ValueError: not enough values to unpack (expected 2, got 1)

Desired behavior for _set_rng_recursive?

Description

I think there might be a bug in sublayer rng setting. In the layers/base.py file, the current code looks like:

trax/trax/layers/base.py, lines 487 to 495 at commit 0a2be4a:

# pylint: disable=protected-access
def _set_rng_recursive(self, rng):
  """Sets the rng (JAX PRNG key) for this layer and sublayers, recursively."""
  self._rng = rng
  sublayers = self.sublayers
  if sublayers:
    rngs = math.random.split(rng, len(sublayers))
    for sublayer, rng in zip(sublayers, rngs):
      sublayer._rng = rng

Shouldn't it be:

      sublayer._set_rng_recursive(rng)

instead of

      sublayer._rng = rng

I am not sure if this is the desired behavior for how base Layers should be used, but see below for a minimal example of the current behavior and of what the behavior becomes after changing that line.

Environment information

N/A

For bugs: reproduction and error logs

# MINIMUM WORKING EXAMPLE
from jax import numpy as np
from jax import random as jax_random
from trax import layers as tl

ser1 = tl.Serial([
    tl.Dense(2),
    tl.Dense(2)
    ])

ser2 = tl.Serial([
    tl.Dense(2),
    tl.Dense(2)
    ])

double_ser = tl.Serial(ser1, ser2)
rng = jax_random.PRNGKey(0)
rng, subkey = jax_random.split(rng)
weights, state = double_ser.init(np.zeros([1,2]), rng=subkey)
print(weights[0][0])
print(weights[0][1])
print(weights[1][0])
print(weights[1][1])

WITHOUT CHANGE (current behavior, all weights are the same):

(DeviceArray([[-0.35201126,  0.34358203],
             [ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))
(DeviceArray([[-0.35201126,  0.34358203],
             [ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))
(DeviceArray([[-0.35201126,  0.34358203],
             [ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))
(DeviceArray([[-0.35201126,  0.34358203],
             [ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))

WITH CHANGE (all weights are different):

(DeviceArray([[ 0.38822186,  0.5617548 ],
             [-0.3487    , -0.47204715]], dtype=float32), DeviceArray([-3.6210415e-07,  2.5783100e-07], dtype=float32))
(DeviceArray([[ 1.195412  , -0.91699356],
             [-0.75880295,  0.7693857 ]], dtype=float32), DeviceArray([3.0248168e-07, 1.8994491e-07], dtype=float32))
(DeviceArray([[-1.1131637 ,  1.1483511 ],
             [ 0.5354116 ,  0.78174126]], dtype=float32), DeviceArray([-8.4182403e-07, -1.8476302e-06], dtype=float32))
(DeviceArray([[-0.39427942, -0.25487363],
             [-0.61524516, -0.75742614]], dtype=float32), DeviceArray([-6.4497357e-07,  9.1538845e-07], dtype=float32))

Trax advantage vs other frameworks?

Description

Dear people,

What would be the advantage of Trax vs TF or PyTorch?

Best,
T.C

Environment information

OS: <your answer here>

$ pip freeze | grep tensor
# your output here

$ pip freeze | grep jax
# your output here

$ python -V
# your output here

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
...

Tensor2Tensor Transformer is not Trax Transformer

Description

Hello. I've been playing around with both the T2T and Trax libraries for a while. Since Trax has several bugs during inference, I decided to switch to T2T. However, it seems to me that the Transformer in Tensor2Tensor is not the same as the one in Trax.

In Tensor2Tensor I create my Transformer model this way:

hparams_my = {
    'batch_size': 128,
    'batch_shuffle_size': 128,
    'use_fixed_batch_size': True,
    'num_hidden_layers': 1,
    'max_input_seq_length': 252,
    'max_target_seq_length': 252,
    'max_length': 252,
    'symbol_modality_num_shards': 1,
    'filter_size': 2048,
    'dropout': 0.1
}

In Trax:

Transformer(input_vocab_size=127,
                output_vocab_size=127,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=1,
                n_decoder_layers=1,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu)

After I run training with T2T, I get this message (twice, by the way):

INFO:tensorflow:Trainable Variables Total size: 7433728
INFO:tensorflow:Trainable Variables Total size: 7433728

Whereas in Trax, after I call trainer.print_n_weights(), I get

Step      0: Total number of trainable weights: 7614591

I would like to note that when I train my Transformer model in Trax, I reach convergence almost immediately (considering the nature of the task: just simple sequence copying with small changes), while with T2T I reach loss values of around 3-4 and no convergence at all.

Could anybody tell me what I have to do? It seems like a common problem with T2T Transformer convergence, but I want to emphasize that the Transformer in Trax is a different one...

[Question] Mixed-up indent amounts.

Hi, first of all, thanks to everyone for this great project.

I just want to point out that indentation by 4 spaces and by 2 spaces is currently used freely, even within the same file (trainer.py).
I know code style is a sensitive topic, so I just want to ask whether it is possible to standardize this (on 4 spaces), or whether there are good reasons not to that I am not aware of?

Thanks.

beam_search.Search() in a single-GPU environment

Description

Training a Transformer converges.

Then beam_search fails, though: when n_devices == 1, some reshapes crash in decode().

Environment information

OS: 
ubuntu 18.04 

CUDA 10.1
1 GPU environment

$ pip freeze | grep tensor
mesh-tensorflow==0.1.11
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0


$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39


$ python -V
Python 3.6.9

For bugs: reproduction and error logs

Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 442, in pure_fn
x, weights=weights, state=state, rng=rng)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 220, in forward_with_state
return self.forward(inputs, weights), state
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File "/root/.local/lib/python3.6/site-packages/trax/layers/attention.py", line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 959, in _reshape_method
return _reshape(a, newshape, order=order)
File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 938, in _reshape
return lax.reshape(a, computed_newshape, None)
File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 640, in reshape
old_sizes=onp.shape(operand))
File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 480, in _forward_abstract
input_signature, weight_signature, self.state, rng)
File "/root/.local/lib/python3.6/site-packages/trax/math/jax.py", line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 273, in abstract_eval_fun
instantiate=True)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 451, in pure_fn
self._caller, signature(x), trace)
trax.layers.base.LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state

File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays

File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)

File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)

File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))

File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)

File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)

File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)

File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))

TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 485, in _forward_abstract
trace)
trax.layers.base.LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file [...]/trax/layers/combinators.py, line 468
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)

File [...]/site-packages/jax/api.py, line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun
instantiate=True)

File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)

File [...]/trax/layers/combinators.py, line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state

File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays

File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)

File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)

File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))

File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)

File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)

File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)

File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))

TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/layers/combinators.py, line 470
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file [...]/trax/layers/combinators.py, line 468
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)

File [...]/site-packages/jax/api.py, line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun
instantiate=True)

File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)

File [...]/trax/layers/combinators.py, line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state

File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays

File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)

File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)

File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))

File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)

File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)

File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)

File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))

TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "math_trax.py", line 565, in
seqs, scores = beam_decoder.decode(inputs=batch, batch_size=iBatch_size)#, )targets_prefix=prefix_for_bs,
File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 602, in decode
dummy=np.zeros(n_devices))
File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 146, in f_jitted
name=flat_fun.name)
File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 642, in call_bind
outs = primitive.impl(f, *args, **params)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 448, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 220, in memoized_fun
ans = call(fun, *args)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 465, in _xla_callable
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 535, in _unreplicated_beam_search
self._get_initial_state(inputs, targets_prefix, batch_size),
File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 490, in _get_initial_state
_, initial_state = self.model(mode='predict').init(signature)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 301
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(2, 1), dtype:int32})

File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)

LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/layers/combinators.py, line 470
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file [...]/trax/layers/combinators.py, line 468
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)

File [...]/site-packages/jax/api.py, line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun
instantiate=True)

File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)

File [...]/trax/layers/combinators.py, line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state

File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays

File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)

File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)

File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))

File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)

File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)

File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)

File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))

TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

REMARK:
1 is n_devices, 2 is batch size, 30 is max_len

# Steps to reproduce:
I tried to force your machine_translation.ipynb in Colab to use the GPU but didn't succeed. Maybe the fastest check on your end is to see what happens with only 1 GPU, as the Colab itself runs smoothly (on a TPU).
# Error logs:
...

Unable to pip install trax on local computer

Description

I was trying to pip install trax on my local computer but I am unable to complete the installation due to some errors (the log below suggests no jaxlib wheel is available for Windows).

Environment information

OS: Windows 10 
Version 1909 
OS Build 18363.628

$ pip freeze | grep tensor
-

$ pip freeze | grep jax
-

$ python -V
Python 3.7.4

For bugs: reproduction and error logs

# Steps to reproduce:
...
Open CMD and run pip install trax
# Error logs:
...
Collecting trax
  Using cached trax-1.2.2-py2.py3-none-any.whl (311 kB)
Collecting jax
  Using cached jax-0.1.58.tar.gz (262 kB)
Requirement already satisfied: numpy in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from tr
ax) (1.18.1)
Requirement already satisfied: scipy in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from tr
ax) (1.4.1)
Collecting gin-config
  Using cached gin_config-0.3.0-py3-none-any.whl (44 kB)
Collecting funcsigs
  Using cached funcsigs-1.0.2-py2.py3-none-any.whl (17 kB)
Requirement already satisfied: absl-py in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from
trax) (0.9.0)
Collecting tensorflow-datasets
  Using cached tensorflow_datasets-2.0.0-py3-none-any.whl (3.1 MB)
Collecting tensor2tensor
  Using cached tensor2tensor-1.15.4-py2.py3-none-any.whl (1.4 MB)
Requirement already satisfied: six in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from trax
) (1.12.0)
Collecting gym
  Using cached gym-0.15.6.tar.gz (1.6 MB)
ERROR: Could not find a version that satisfies the requirement jaxlib (from trax) (from versions: none)
ERROR: No matching distribution found for jaxlib (from trax)

Reformer model with long sequence throws an error

I'm trying to train the basic Reformer (not the ReformerLM) on a long sequence of text, based on the language generation example. Simply replacing the ReformerLM class with Reformer and removing the mask, then feeding in the entire Crime and Punishment book, throws the following error:

TypeError: requesting more random bits than a single call provides.

Everything works fine if I cut the input down to smaller sequences. The example can be seen in the following notebook:
https://colab.research.google.com/drive/1C9KOHOfuVhoOqzRKx_rRaeOV3jPvuZ_L
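
As the reporter notes, shorter sequences avoid the error, so one workaround until the underlying PRNG limit is addressed is to split the long text into fixed-size windows before feeding it to the model. A minimal sketch, assuming the tokens are already a flat sequence of ids and a max_len of 2048:

import numpy as np

def chunk_tokens(token_ids, max_len=2048):
    # Split one long token sequence into fixed-size chunks; the final
    # chunk is zero-padded so every row has length max_len.
    chunks = []
    for start in range(0, len(token_ids), max_len):
        chunk = np.asarray(token_ids[start:start + max_len])
        if len(chunk) < max_len:
            chunk = np.pad(chunk, (0, max_len - len(chunk)))
        chunks.append(chunk)
    return np.stack(chunks)  # shape: (n_chunks, max_len)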

output_dir = None breaks Trainer

Description

When the output_dir argument is set to None when creating a trax.supervised.Trainer, the reset method is never called (lines 216-217 in trainer_lib: if output_dir is not None: self.reset(output_dir)). This breaks the Trainer, because reset sets the train stream among other things, and a model can't function without the train stream. Surely you would still want a train stream even when not writing the model to an output dir? Is there a reason for the if condition in line 216? A workaround sketch follows.
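
Until that changes, one can pass a throwaway directory instead of None so that reset runs and the train stream is set up. A minimal sketch (constructor arguments other than output_dir are elided):

import tempfile
import trax

# Pass a throwaway directory instead of None so Trainer.reset() runs,
# even if you do not care about the checkpoints written there.
output_dir = tempfile.mkdtemp()
trainer = trax.supervised.Trainer(..., output_dir=output_dir)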

For bugs: reproduction and error logs

# Steps to reproduce:
Create a trainer with any model without setting an output_dir.
# Error logs:
...

Program using supervised.train does not end

Description

I'm running command:

python -m trax.trainer --config_file=$PWD/trax/configs/mlp_mnist.gin

It trains the model, prints Finished training., and then hangs forever. The process is not using CPU, but it does not exit either; it never returns to the shell and I have to terminate it with Ctrl-C.
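
To see what keeps the interpreter alive after training returns, one diagnostic sketch (standard-library Python only, not part of trax) is to dump every live thread's stack at that point:

import faulthandler
import sys
import threading

# Dump the stack of every live thread and list the non-daemon ones --
# any non-daemon thread can keep the process from exiting.
faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
print([t.name for t in threading.enumerate() if not t.daemon])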

Environment information

Trax: 1.2.2

OS: Ubuntu 18.04

$ pip freeze | grep tensor

mesh-tensorflow==0.1.4
neptune-tensorboard==0.3.8
tensor2tensor==1.14.1
tensorboard==1.15.0
tensorflow==1.15.0
tensorflow-datasets==1.3.0
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.1
tensorflow-probability==0.8.0


$ pip freeze | grep jax

jax==0.1.51
jaxlib==0.1.32

$ python -V
Python 3.7.3

# Steps to reproduce:

python -m trax.trainer --config_file=$PWD/trax/configs/mlp_mnist.gin

# Error logs:

I0217 13:01:30.008547 140681594124096 trainer_lib.py:751] Step   2000: Ran 200 train steps in 4.19 secs
Step   2000: Ran 200 train steps in 4.19 secs
I0217 13:01:30.012206 140681594124096 trainer_lib.py:751] Step   2000: Evaluation
Step   2000: Evaluation
I0217 13:01:30.073523 140681594124096 trainer_lib.py:751] Step   2000: train                   accuracy |  0.99804688
Step   2000: train                   accuracy |  0.99804688
I0217 13:01:30.074985 140681594124096 trainer_lib.py:751] Step   2000: train                       loss |  0.01065689
Step   2000: train                       loss |  0.01065689
I0217 13:01:30.076281 140681594124096 trainer_lib.py:751] Step   2000: train         neg_log_perplexity |  0.01065689
Step   2000: train         neg_log_perplexity |  0.01065689
I0217 13:01:30.077118 140681594124096 trainer_lib.py:751] Step   2000: train weights_per_batch_per_core |  256.00000000
Step   2000: train weights_per_batch_per_core |  256.00000000
I0217 13:01:30.423881 140681594124096 trainer_lib.py:751] Step   2000: eval                    accuracy |  0.96406251
Step   2000: eval                    accuracy |  0.96406251
I0217 13:01:30.424737 140681594124096 trainer_lib.py:751] Step   2000: eval                        loss |  0.62180674
Step   2000: eval                        loss |  0.62180674
I0217 13:01:30.426048 140681594124096 trainer_lib.py:751] Step   2000: eval          neg_log_perplexity |  0.62180674
Step   2000: eval          neg_log_perplexity |  0.62180674
I0217 13:01:30.426502 140681594124096 trainer_lib.py:751] Step   2000: eval  weights_per_batch_per_core |  256.00000000
Step   2000: eval  weights_per_batch_per_core |  256.00000000
I0217 13:01:30.427090 140681594124096 trainer_lib.py:751] Step   2000: Finished evaluation
Step   2000: Finished evaluation
I0217 13:01:30.445652 140681594124096 trainer_lib.py:751] Model saved to /home/pawel/trax/MLP_mnist_20200217_1300/model.pkl
I0217 13:01:30.446032 140681594124096 trainer_lib.py:751] Step   2000: Training done
Step   2000: Training done
I0217 13:01:30.446371 140681594124096 trainer_lib.py:751] Finished training.
Finished training.
