
t5x's Introduction

T5X

Go to T5X ReadTheDocs Documentation Page.

T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.

It is essentially a new and improved implementation of the T5 codebase (based on Mesh TensorFlow) in JAX and Flax. To learn more, see the T5X Paper.

Below is a quick start guide for training models with TPUs on Google Cloud. For additional tutorials and background, see the complete documentation.

Quickstart (Recommended)

T5X can be run with XManager on Vertex AI. Vertex AI is a training platform that creates TPU instances, runs code on the TPUs, and shuts the TPUs down when the jobs terminate. This is significantly easier than managing GCE VMs and TPU VM instances yourself.

  1. Follow the prerequisites and directions to install XManager.

  2. Request TPU quota as required. GCP projects come with 8 cores by default, which is enough to run one training experiment on a single TPU host. If you want to run multi-host training or run multiple trials in parallel, you will need more quota. Navigate to Quotas.

The quota you want is:

  • Service: Vertex AI API
  • Dimensions (location): us-central1
  • If you want to run single-host experiments:
    • Custom model training TPU V2 cores per region
    • Custom model training TPU V3 cores per region
  • If you want to run multi-host experiments:
    • Custom model training TPU V2 pod cores per region
    • Custom model training TPU V3 pod cores per region

TIP: You won't be able to run single-host experiments with multi-host quota (i.e., you can't run tpu_v2=8 using TPU V2 pod quota).

  3. Launch the XManager script located at t5x/scripts/xm_launch.py.

As a running example, we use the WMT14 En-De translation which is described in more detail in the Examples section below.

export GOOGLE_CLOUD_BUCKET_NAME=...
export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data
export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d)

# Pre-download the dataset for multi-host experiments.
tfds build wmt_t2t_translate --data_dir=$TFDS_DATA_DIR

git clone https://github.com/google-research/t5x
cd ./t5x/

python3 ./t5x/scripts/xm_launch.py \
  --gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin \
  --model_dir=$MODEL_DIR \
  --tfds_data_dir=$TFDS_DATA_DIR

Check gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/ for the output artifacts, which can be read by TensorBoard.
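
For example, if TensorBoard (with GCS support) is installed on your local machine, you can point it at that prefix directly; the exact subdirectory is whatever MODEL_DIR was set to above:

# Run on your own machine, not the TPU VM; assumes TensorBoard can read GCS paths.
tensorboard --logdir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/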

GPU Usage

Note: NVIDIA has released an updated version of this repository with H100 FP8 support and broad GPU performance improvements. Please visit the NVIDIA Rosetta repository for more details and usage instructions.

T5X can be run easily on GPUs, either in single-node configurations or in multi-node configurations with a SLURM+pyxis cluster. Further instructions are available at t5x/contrib/gpu. The t5x/contrib/gpu/scripts_gpu folder contains example scripts for pretraining T5X on The Pile and for fine-tuning on SQuAD and MNLI. These scripts and the associated Gin configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the NVIDIA Rosetta repository.

Installation

Note that all the commands in this document should be run in the command line of the TPU VM instance unless otherwise stated.

  1. Follow the instructions to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU API.

    Note: T5X also works with GPU; please follow the instructions in t5x/contrib/gpu if you'd like to use the GPU version.

  2. Create a Cloud TPU VM instance following these instructions (a sample creation command is sketched after this list). We recommend that you develop your workflow in a single v3-8 TPU (i.e., --accelerator-type=v3-8) and scale up to pod slices once the pipeline is ready. In this README, we focus on using a single v3-8 TPU. See here to learn more about TPU architectures.

  3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM. You can install packages, run your code, etc. on the host machine. Once the TPU instance is created, ssh into it with

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}

    where TPU_NAME and ZONE are the name and the zone used in step 2.

  4. Install T5X and the dependencies.

    git clone --branch=main https://github.com/google-research/t5x
    cd t5x
    
    python3 -m pip install -e '.[tpu]' -f \
      https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. Create a Google Cloud Storage (GCS) bucket to store the dataset and model checkpoints. To create a GCS bucket, see these instructions.

  6. (optional) If you prefer working with a Jupyter/Colab-style environment, you can set up a custom Colab runtime by following the steps in t5x/notebooks.
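
For step 2 above, a TPU VM creation command along the following lines can be used (a sketch; the runtime version passed to --version is an assumption, so use the one recommended by the Cloud TPU documentation):

gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
  --zone=${ZONE} \
  --accelerator-type=v3-8 \
  --version=tpu-vm-base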

Example: English to German translation

As a running example, we use the WMT14 En-De translation. The raw dataset is available in TensorFlow Datasets as "wmt_t2t_translate".

T5 casts the translation task by converting an example such as the following

{'en': 'That is good.', 'de': 'Das ist gut.'}

to the form called "text-to-text":

{'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}

This formulation allows many different classes of language tasks to be expressed in a uniform manner, so a single encoder-decoder architecture can handle them without any task-specific parameters. For more detail, refer to the T5 paper (Raffel et al. 2019).
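
As an illustration (not part of the official examples), the translate preprocessor from the T5 library performs exactly this conversion. A minimal sketch, assuming the t5 package and TensorFlow are installed:

import tensorflow as tf
from t5.data import preprocessors

# One raw WMT-style example.
ds = tf.data.Dataset.from_tensors({'en': 'That is good.', 'de': 'Das ist gut.'})
# Map it to the text-to-text form.
ds = preprocessors.translate(ds, source_language='en', target_language='de')
print(next(iter(ds.as_numpy_iterator())))
# Roughly: {'inputs': b'translate English to German: That is good.',
#           'targets': b'Das ist gut.'}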

For a scalable data pipeline and an evaluation framework, we use SeqIO, which was factored out of the T5 library. A seqio.Task packages together the raw dataset, the vocabulary, preprocessing (such as tokenization), and evaluation metrics (such as BLEU), and provides a tf.data instance.

The T5 library provides a number of seqio.Tasks that were used in the T5 paper. In this example, we use wmt_t2t_ende_v003.
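
A quick way to peek at this task is sketched below; it assumes the t5 and seqio packages are installed and that the TFDS data has already been downloaded (see the next step), and the sequence lengths are illustrative:

import seqio
import t5.data.tasks  # Registers the T5 tasks, including wmt_t2t_ende_v003.

task = seqio.TaskRegistry.get("wmt_t2t_ende_v003")
ds = task.get_dataset(
    sequence_length={"inputs": 64, "targets": 64}, split="validation")
for ex in ds.take(1).as_numpy_iterator():
    print(ex.keys())  # Tokenized 'inputs' and 'targets' features.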

Before training or fine-tuning, you need to download the "wmt_t2t_translate" dataset (https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate).

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."

# Make sure that dataset package is up-to-date.
python3 -m pip install --upgrade tfds-nightly

# Pre-download dataset.
tfds build wmt_t2t_translate --data_dir=${TFDS_DATA_DIR}

Training

To run a training job, we use the t5x/train.py script.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.
TFDS_DATA_DIR="..."

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

The configuration for this training run is defined in the Gin file base_wmt_from_scratch.gin. Gin-config is a library to handle configurations based on dependency injection. Among many benefits, Gin allows users to pass custom components such as a custom model to the T5X library without having to modify the core library. The custom components section shows how this is done.

While the core library is independent of Gin, it is central to the examples we provide. Therefore, we provide a short introduction to Gin in the context of T5X. All the configurations are written to a file "config.gin" in MODEL_DIR. This makes debugging as well as reproducing the experiment much easier.

In addition to config.gin, a model-info.txt file summarizes the model parameters (shapes, names of the axes, partitioning info) as well as the optimizer states.
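
If MODEL_DIR points at a GCS path, these artifacts can be inspected directly with gsutil, e.g.:

gsutil cat ${MODEL_DIR}/config.gin
gsutil cat ${MODEL_DIR}/model-info.txt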

TensorBoard

To monitor the training in TensorBoard, it is much easier (due to authentication issues) to launch TensorBoard on your own machine rather than in the TPU VM. So, on your own machine (not in the command line where you ssh'ed into the TPU VM), launch TensorBoard with the logdir pointing to MODEL_DIR.

# NB: run this on your machine not TPU VM!
MODEL_DIR="..."  # Copy from the TPU VM.
tensorboard --logdir=${MODEL_DIR}

Or you can launch the TensorBoard inside a Colab. In a Colab cell, run

from google.colab import auth
auth.authenticate_user()

to authorize the Colab to access the GCS bucket and launch the TensorBoard.

%load_ext tensorboard
model_dir = "..."  # Copy from the TPU VM.
%tensorboard --logdir={model_dir}

Fine-tuning

We can leverage the benefits of self-supervised pre-training by initializing from one of our pre-trained models. Here we use the T5.1.1 Base checkpoint.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note: when supplying a string, dict, list, tuple value, or a bash variable via a flag, you must put it in quotes. In the case of strings, it requires escaped quotes (\"<string>\"). For example: --gin.utils.DatasetConfig.split=\"validation\" or --gin.MODEL_DIR=\"${MODEL_DIR}\".

Gin makes it easy to change a number of configurations. For example, you can change partitioning.PjitPartitioner.num_partitions (overriding the value in base_wmt_from_scratch.gin) to change the parallelism strategy and pass it as a command-line arg.

--gin.partitioning.PjitPartitioner.num_partitions=8

Evaluation

To run offline evaluation (i.e., without training), you can use the t5x/eval.py script.

EVAL_OUTPUT_DIR="..."  # directory to write eval output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/eval.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin" \
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
  --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Inference

To run inference, you can use the t5x/infer.py script. Here we use the same seqio.Task, but for inference we do not use the 'targets' feature other than logging it alongside the prediction in a JSON file.

INFER_OUTPUT_DIR="..."  # directory to write infer output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/infer.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin" \
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
  --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Exporting as TensorFlow Saved Model

A pretrained model can be exported as a TensorFlow SavedModel and deployed to the Vertex AI Prediction service using the Optimized TensorFlow Runtime (https://cloud.google.com/vertex-ai/docs/predictions/optimized-tensorflow-runtime). Please note that the exported model won't work on an OSS-based TensorFlow Model Server.

T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
CHECKPOINT_PATH="..."

BATCH_SIZE=None
BEAM_SIZE=1

# Use 'bfloat16' if you plan to run exported model on NVIDIA A100 or newer GPUs,
# for other GPUs use 'float32'.
ACTIVATION_DTYPE=bfloat16

# Version numbers must be numeric. We generate one based on datetime.
VERSION=$(date +%Y%m%d%H%M%S)

NAME=t5x_base_${ACTIVATION_DTYPE}  # Model name.

# Path to export the model to. Note that the export script will add a _cpu
# suffix after the model name.
OUTPUT=${CHECKPOINT_PATH}/saved_model.${NAME}/${VERSION}

declare -a ARGS=(
--gin_file=t5x/examples/t5/t5_1_1/base.gin
--gin_file=t5x/t5x/configs/runs/export.gin
--gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}"
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\"
--gin.MODEL_NAME=\"/ml/${USER}/t5x_base\"
--gin.MODEL_OUTPUT_DIR=\"${OUTPUT}\"
--gin.BEAM_SIZE=${BEAM_SIZE}
--gin.BATCH_SIZE=${BATCH_SIZE}
--gin.export_lib.save.partitioner=None
--gin.export_lib.save.warmup_examples="['hello world']"
--gin.export_lib.ExportableModule.use_batch_function=False
--gin.export_lib.ExportableModule.use_gpu=False
--gin.export_lib.ExportableModule.jit_compile=False
--gin.ACTIVATION_DTYPE=\"${ACTIVATION_DTYPE}\"
--gin.network.T5Config.dtype=\"${ACTIVATION_DTYPE}\"
--gin.utils.RestoreCheckpointConfig.dtype=\"${ACTIVATION_DTYPE}\"
--gin.DROPOUT_RATE=0.0
)

(python3 ${T5X_DIR}/t5x/export.py "${ARGS[@]}")

For a detailed definition of the arguments, refer to export.gin (t5x/configs/runs/export.gin).

You can run XL and smaller models on NVIDIA A100 40GB, and XXL models on NVIDIA A100 80GB.

Custom components

The translation example uses the encoder-decoder model that T5X provides as well as the dataset from the T5 library. This section shows how you can use your own dataset and model and pass them via Gin.

Example: custom dataset in a user directory

For this example, we have the following directory structure with ${HOME}/dir1/user_dir representing a user directory with custom components.

${HOME}
└── dir1
    └── user_dir
        ├── t5_1_1_base_de_en.gin
        └── tasks.py

As an example, let's define a new dataset. Here we use the same translation dataset, but we define the task in the opposite direction, i.e., German to English instead of English to German. We define this task in tasks.py.

# ${HOME}/dir1/user_dir/tasks.py

import functools
import seqio
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors

vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model', extra_ids=100)
output_features = {
    'inputs': seqio.Feature(vocabulary=vocabulary),
    'targets': seqio.Feature(vocabulary=vocabulary)
}

seqio.TaskRegistry.add(
    'wmt_t2t_de_en_v003',
    source=seqio.TfdsDataSource(tfds_name='wmt_t2t_translate/de-en:1.0.0'),
    preprocessors=[
        functools.partial(
            preprocessors.translate,
            source_language='de', target_language='en'),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.bleu],
    output_features=output_features)

In the Gin file, most of the settings are equivalent to those used in the En->De example, so we include the Gin file from that example. To use the "wmt_t2t_de_en_v003" task we just defined, we need to import the task module "tasks.py". Note that we use a relative path defined with respect to the user directory; this will be specified as a flag.

# ${HOME}/dir1/user_dir/t5_1_1_base_de_en.gin
from __gin__ import dynamic_registration
import tasks  # This imports the task defined in dir1/user_dir/tasks.py.

include "t5x-tmp/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin"
MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003"

Finally, we launch training, passing the user directory via the gin_search_paths flag so that the Gin file and Python modules can be specified with relative paths.

PROJECT_DIR=${HOME}"/dir1/user_dir"
T5X_DIR="..."  # directory where the t5x is cloned.
TFDS_DATA_DIR="..."
MODEL_DIR="..."
export PYTHONPATH=${PROJECT_DIR}

python3 ${T5X_DIR}/t5x/train.py \
  --gin_search_paths=${PROJECT_DIR} \
  --gin_file="t5_1_1_base_de_en.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Checkpoints

Native Checkpoints

We have released the checkpoints of many of the original T5 models and their variants in a native T5X format for maximal efficiency. See the complete list, including the matching Gin configuration files.

These are converted from the public Mesh TensorFlow checkpoints.
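
For example, the fine-tuning run above can be pointed at the released T5 1.1 Base checkpoint in native T5X format with an override like the following (illustrative; the macro name matches the fine-tuning configs referenced elsewhere on this page):

--gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000\"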

Compatibility with the Mesh TensorFlow checkpoints

The Mesh TensorFlow checkpoints trained using the T5 library can be directly loaded into T5X. For example, we can rerun the fine-tuning example initializing from the MTF checkpoint by changing the INIT_CHECKPOINT Gin macro.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \
  --gin.INIT_CHECKPOINT=\"gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note that restoring directly from the Mesh TensorFlow checkpoints can be inefficient if heavy model parallelism is used for large models. This is because each host first loads the entire copy of the model and then keeps only the relevant slices dictated by the model parallelism specification. If you have Mesh TensorFlow checkpoints that you run often, we recommend converting the checkpoints to the T5X native format using the convert_tf_checkpoint script.
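
A sketch of such a conversion is shown below. The gin parameter names are assumptions; consult t5x/scripts/convert_tf_checkpoint.py for the actual arguments:

python3 ${T5X_DIR}/t5x/scripts/convert_tf_checkpoint.py \
  --gin_file=t5x/examples/t5/t5_1_1/base.gin \
  --gin.convert_checkpoint.tf_checkpoint_path=\"gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000\" \
  --gin.convert_checkpoint.output_dir=\"${MODEL_DIR}/t5x_native_checkpoint\" \
  --logtostderr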

Citing T5X

Please use the following bibtex entry to cite T5X.

@article{roberts2022t5x,
  url = {https://arxiv.org/abs/2203.17189},
  author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra, Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester, Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy, Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel, Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov, Alexander and Newlan, Joshua and Gesmundo, Andrea},
  title = {Scaling Up Models and Data with $\texttt{t5x}$ and $\texttt{seqio}$},
  journal={arXiv preprint arXiv:2203.17189},
  year = {2022},
}

Note

This is not an officially supported Google product

t5x's People

Contributors

0x0539, adarob, afrozenator, andrewluchen, blester125, cghawthorne, chiamp, cpgaffney1, ebrevdo, fehiepsi, gauravmishra, gshennvm, hawkinsp, hwchung27, iansimon, jacobaustin123, jekbradbury, kehang, kkenealy, levskaya, liangyaning33, marvin182, maxwillzq, nconstant-google, qstanczyk, sahiljain314, sauravmaheshkar, texasmichelle, voutcn, yashk2810


t5x's Issues

Request for more robust release management

Would it be possible to establish some release management process? E.g. tagging and release notes.

Some commits and rollbacks, like the recent rollback of 'Enable GDA by default in T5X training and checkpointing' introduce breaking changes.

Training freezes when trying to evaluate previously saved checkpoint

I'm fine-tuning an mT5 small checkpoint on a custom mixture. For debugging, I've set eval_period=25 and utils.SaveCheckpointConfig.period=50 in finetune.gin.
The training job freezes inside trainer.py whenever it tries to compute metrics on a saved checkpoint (see the attached screenshot).


Support forward mode differentiation

Currently, forward mode differentiation does not work because losses.py implements cross_entropy_with_logits using jax.custom_vjp. If it were implemented with jax.custom_jvp, one would get both forward and reverse mode support. An example application of forward mode differentiation is inspecting the Hessian and the eigenvalues of a model.

Custom dataset gin file

I have followed this section to create my own dataset and run the train.py script. However, I get the following error:

ModuleNotFoundError: No module named 'tasks'
  In file "/root/dir1/user_dir/t5_1_1_base_custom_v1.gin", line 2
    import tasks

Any help would be appreciated!

Using Gradient Accumulation

Hey guys!

I am about to pretrain a monolingual model using T5X (thank you for this!).

The routine I'll be following is based on the ByT5 paper. However, I currently have access to a smaller TPU (a v3-8), so a batch size of 2^20 would not fit into its memory. To accomplish similar results, I am thinking of using gradient accumulation, so I can emulate the same batch size used for pretraining the original ByT5.

I couldn't find any documentation about this, but looking through the code, I guess I would have to:

  1. Specify BATCH_SIZE = 1024 (2^20/1024)
  2. Specify trainer.num_microbatches = 16

With this, I hope to fit 64 examples in a step (does T5X auto distribute 8 samples per core here?), but update gradients every 16 steps, emulating a 1024 batch size.

The resulting gin would be something like:

include 't5x/examples/t5/byt5/small.gin'
include 't5x/configs/runs/pretrain.gin'

TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 189}
TRAIN_STEPS = 1_000_000
DROPOUT_RATE = 0.0
BATCH_SIZE = 1024

trainer.Trainer:
  num_microbatches = 16

Is that correct?

Documentation issues

@adarob

Thanks for making this available. Just a couple minor issues on the main page that could make stuff easier:

  • The referenced script "t5x/examples/t5/t5_1_1/examples/base_wmt_train.gin" does not exist. I guess it is "t5x/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" that should be used.

  • Neither does the base finetuning script that is referred to further down the page. There is a "t5x/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin"; this loads a small model but tries to load a base checkpoint: "INITIAL_CHECKPOINT_PATH = "gs://t5-data//pretrained_models/t5x/t5_1_1_base/checkpoint_1000000"". Small typo in the checkpoint path though; the correct one is "gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000".

I see that the tfds preparation instructions just have a note to @hwchung27. This would be really useful. I am particularly interested in the process of preparing a domain-specific dataset (for instance, another language) for further pretraining.

Outputting results from evaluation

I am following the examples for running eval-only. I add the following example metric to the task file:

metric_fns=[metrics.accuracy],

The evaluation runs without any errors, and it also writes the files config.gin and model-info.txt to the output directory. However, there do not seem to be any other results. Are there other steps that need to be performed here?

`num_partitions` does not work for GPU

Hi, thanks for the great work.
I have already carefully read the docs on partitioning, but I am still confused about how it works and what the partitioning rules mean.
I tried to run the pretraining code on a single node with 8 A100 GPUs. When I pretrain T5 with the Hugging Face trainer and DeepSpeed ZeRO-2, it works well. However, when I run the pretraining with the scripts provided in the examples with

partitioning.PjitPartitioner:
  num_partitions = 1
  logical_axis_rules= @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  activation_partitioning_dims = 2
  parameter_partitioning_dims = 2

I get the following errors:

Traceback (most recent call last):
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 659, in <module>
    gin_utils.run(main)
  File "/mnt/cache/namco/t5x/t5x/gin_utils.py", line 105, in run
    app.run(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/cache/namco/t5x/t5x/train.py", line 637, in main
    _main(argv)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 657, in _main
    train_using_gin()
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/mnt/cache/namco/t5x/t5x/train.py", line 321, in train
    train_state = train_state_initializer.from_checkpoint_or_scratch(
  File "/mnt/cache/namco/t5x/t5x/utils.py", line 523, in from_checkpoint_or_scratch
    or self.from_scratch(init_rng))
  File "/mnt/cache/namco/t5x/t5x/utils.py", line 395, in from_scratch
    return p_initialize_train_state_fn(init_rng)
  File "/mnt/cache/namco/t5x/t5x/partitioning.py", line 729, in __call__
    return self._pjitted_fn(*args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 267, in wrapped
    args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 246, in infer_params
    jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 411, in _pjit_jaxpr
    _check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape,
  File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 588, in _check_shapes_against_resources
    raise ValueError(f"One of {what} was given the resource assignment "
ValueError: One of pjit outputs was given the resource assignment of PartitionSpec('model', None), which implies that the size of its dimension 0 should be divisible by 8, but it is equal to 12
  In call to configurable 'train' (<function train at 0x7f598523c790>)

Could you please help me to fix this error?

Error after update in train.py

@adarob @t5-copybara

I have started getting this error:
'''
File "../../t5x/t5x/train.py", line 254, in train
train_iter = clu.data.TfDatasetIterator(train_iter, checkpoint=True)
TypeError: __init__() got an unexpected keyword argument 'checkpoint'
'''

I noticed this line was updated yesterday.

Encoder-only T5

Hi,
Thanks for the great work!

Just a quick question: I'm looking to replicate the results of the EncT5 paper and was wondering if this repo could be useful. Is there an easy way to extract (and finetune) an encoder-only T5 model using your implementation? If not, is this something on your roadmap?

Inference epoch

Why does inference happen in multiple epochs? The results of each epoch are saved in a separate file.

I0111 07:53:16.425345 139870442810432 infer.py:416] Running inference on 100 batches.
I0111 07:53:16.443239 139870442810432 utils.py:697] length of dataset = 800
I0111 07:53:16.448796 139870442810432 utils.py:723] The infer dataset is sharded into 1 shards with per-shard batch size of 8
I0111 07:53:16.507651 139870442810432 utils.py:733] Inference of batch [0 1 2 3 4 5 6 7] done
....
I0111 07:54:25.712660 139870442810432 utils.py:733] Inference of batch [792 793 794 795 796 797 798 799] done.
I0111 07:54:25.720639 139870442810432 utils.py:743] Inference of all batches done.
I0111 07:54:58.374256 139870442810432 infer.py:425] Epoch completed in 101.979152 seconds (7.844741 examples/sec).
I0111 07:54:58.427566 139870442810432 infer.py:443] Checkpoint written to temporary location in 0.052977 seconds.
I0111 07:54:58.427980 139851823490816 infer.py:382] Writing epoch 1 results to /home/torinaki/src/product-description-generation/infer/tmp-product_descriptions_dataset-00000-of-00001/product_descriptions_dataset-predict.jsonl-00000-of-00001-epoch00001
I0111 07:55:00.359092 139851823490816 infer.py:386] Writing completed in 1.931091 seconds (414.273574 examples/sec).
I0111 07:55:00.368155 139870442810432 infer.py:401] Starting epoch 2
I0111 07:55:00.404672 139870442810432 infer.py:416] Running inference on 100 batches.
I0111 07:55:00.425312 139870442810432 utils.py:697] length of dataset = 800
I0111 07:55:00.431032 139870442810432 utils.py:723] The infer dataset is sharded into 1 shards with per-shard batch size of 8
I0111 07:55:00.511902 139870442810432 utils.py:733] Inference of batch [0 1 2 3 4 5 6 7] done.
I0111 07:55:00.553246 139870442810432 utils.py:733] Inference of batch [ 8  9 10 11 12 13 14 15] done.
...
I0111 07:56:09.090829 139870442810432 utils.py:733] Inference of batch [792 793 794 795 796 797 798 799] done.
I0111 07:56:09.098065 139870442810432 utils.py:743] Inference of all batches done.
I0111 07:56:42.189918 139870442810432 infer.py:425] Epoch completed in 101.821549 seconds (7.856883 examples/sec).
I0111 07:56:42.248841 139870442810432 infer.py:443] Checkpoint written to temporary location in 0.058321 seconds.
I0111 07:56:42.249297 139851823490816 infer.py:382] Writing epoch 2 results to /home/torinaki/src/product-description-generation/infer/tmp-product_descriptions_dataset-00000-of-00001/product_descriptions_dataset-predict.jsonl-00000-of-00001-epoch00002
I0111 07:56:44.204791 139851823490816 infer.py:386] Writing completed in 1.955470 seconds (409.108840 examples/sec).
I0111 07:56:44.234261 139870442810432 infer.py:401] Starting epoch 3
...

ValueError: None values not supported

Upon running a seqio mixture on mT5 and ByT5, I get an error stating: ValueError: None values not supported

I am currently using a seqio mixture that I define in my task.py file, and I use the default mT5 tokenizer gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model with extra_ids=0.

This is how my task.py file looks

import functools
import seqio
import tensorflow as tf
import t5.data
from datasets import load_from_disk, load_dataset
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

TaskRegistry = seqio.TaskRegistry
vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)


DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary, add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=vocabulary, add_eos=True)
}



def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_path=None):
    dataset = load_dataset(dataset_path, streaming=True, use_auth_token=True)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:
        for item in dataset[str(split)]:
            yield item[column]


def dataset_fn(split, shuffle_files, seed=None, dataset_path=None):
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_path=dataset_path),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_path)
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map"""
    return {**key_map, target_key: x}

# link to the mt5 sentencepiece tokenizer vocabulary
vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)

TaskRegistry.add(
    "hindi_span_curruption",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_path='StephennFernandes/ciil_mega_corpus_hindi'),
        splits=("train", "validation"),
        caching_permitted=False),
    preprocessors=[
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        # seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption, 
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"],"inputs": seqio.Feature(vocabulary=vocabulary,add_eos=True)},
    metric_fns=[]
)
### similar multiple tasks exist for multiple languages. ### 

seqio.MixtureRegistry.add(
  "ciil_mix_3",
  ["assamese_span_curruption", "bengali_span_curruption", 
  "bhisnupuriya_span_curruption", "bodo_span_curruption", 
  "divehi_span_curruption", "dogri_span_curruption", 
  "english_span_curruption", "gujarati_span_curruption",
  "hindi_span_curruption", "kannada_span_curruption", 
  "kashmiri_span_curruption", "konkani_span_curruption", 
  "maithili_span_curruption", "malayalam_span_curruption",
  "manipuri_span_curruption", "marathi_span_curruption",
  "nepali_span_curruption", "odia_span_curruption",
  "panjabi_span_curruption", "sanskrit_span_curruption",
  "tamil_span_curruption", "telugu_span_curruption",
   "urdu_span_curruption" ],
  default_rate=3
)

I use the ciil_mix_3 mixture in my .gin file.
This is how my .gin file looks:

from __gin__ import dynamic_registration
import t5.data.mixtures
import __main__ as train_script


include 't5x/examples/t5/mt5/base.gin'
include 't5x/configs/runs/pretrain.gin'

import task 

MIXTURE_OR_TASK_NAME = "ciil_mix_3"
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
TRAIN_STEPS = 100000
DROPOUT_RATE = 0.0
BATCH_SIZE = 32


train_script.train:
  eval_period = 2000

The following is the entire stack trace of the same:

Traceback (most recent call last):
  File "/home/stephen/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/stephen/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 748, in <module>
    gin_utils.run(main)
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/gin_utils.py", line 107, in run
    app.run(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 708, in main
    _main(argv)
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 744, in _main
    train_using_gin()
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 249, in train
    train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1366, in get_dataset
    return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1387, in get_dataset_inner
    ds = seqio.get_dataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1671, in get_dataset
    ds = mixture_or_task.get_dataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1457, in get_dataset
    datasets = [
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1458, in <listcomp>
    task.get_dataset(  # pylint:disable=g-complex-comprehension
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1209, in get_dataset
    ds = self.preprocess_postcache(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1044, in preprocess_postcache
    dataset = self._preprocess_dataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 965, in _preprocess_dataset
    dataset = prep_fn(dataset, **kwargs)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/preprocessors.py", line 83, in tokenize
    return utils.map_over_dataset(fn=tokenize_fn)(dataset)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/utils.py", line 778, in wrapped_fn
    return ds.map(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2050, in map
    return ParallelMapDataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 5284, in __init__
    self._map_func = structured_function.StructuredFunctionWrapper(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 271, in __init__
    self._function = fn_factory()
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2567, in get_concrete_function
    graph_function = self._get_concrete_function_garbage_collected(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2533, in _get_concrete_function_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 248, in wrapped_fn
    ret = wrapper_helper(*args)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 177, in wrapper_helper
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
    raise e.ag_error_metadata.to_exception(e)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 689, in wrapper
    return converted_call(f, args, kwargs, options=options)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/tmp/__autograph_generated_fileu9gu1w4n.py", line 8, in <lambda>
    tf__lam = lambda arg: ag__.with_function_scope(lambda lscope: ag__.converted_call(fn, (arg,) + tuple(args), dict(**kargs), lscope), 'lscope', ag__.STD)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/core/function_wrappers.py", line 113, in with_function_scope
    return thunk(scope)
  File "/tmp/__autograph_generated_fileu9gu1w4n.py", line 8, in <lambda>
    tf__lam = lambda arg: ag__.with_function_scope(lambda lscope: ag__.converted_call(fn, (arg,) + tuple(args), dict(**kargs), lscope), 'lscope', ag__.STD)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 352, in converted_call
    return converted_call(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/tmp/__autograph_generated_filezbhafqmt.py", line 113, in tf__tokenize_impl
    ag__.for_stmt(ag__.converted_call(ag__.ld(features).items, (), None, fscope), None, loop_body, get_state_4, set_state_4, (), {'iterate_names': '(k, v)'})
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 449, in for_stmt
    _py_for_stmt(iter_, extra_test, body, None, None)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 498, in _py_for_stmt
    body(target)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 464, in protected_body
    original_body(protected_iter)
  File "/tmp/__autograph_generated_filezbhafqmt.py", line 105, in loop_body
    ag__.if_stmt(ag__.ld(k) in ag__.ld(output_features), if_body_3, else_body_3, get_state_3, set_state_3, ('v',), 1)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1341, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1394, in _py_if_stmt
    return body() if cond else orelse()
  File "/tmp/__autograph_generated_filezbhafqmt.py", line 63, in if_body_3
    v = ag__.converted_call(ag__.ld(vocab).encode_tf, (ag__.ld(v),), None, fscope)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
    result = converted_f(*effective_args)
  File "/tmp/__autograph_generated_filef9jwq2ra.py", line 13, in tf__encode_tf
    retval_ = ag__.converted_call(ag__.ld(self)._encode_tf, (ag__.ld(s),), None, fscope)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
    result = converted_f(*effective_args)
  File "/tmp/__autograph_generated_filezpl5g8b_.py", line 21, in tf___encode_tf
    retval_ = ag__.converted_call(ag__.ld(self).tf_tokenizer.tokenize, (ag__.ld(s),), None, fscope)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
    result = converted_f(*effective_args)
  File "/tmp/__autograph_generated_filet9vre1mq.py", line 22, in tf__tokenize
    input_tensor = ag__.converted_call(ag__.ld(ragged_tensor).convert_to_tensor_or_ragged_tensor, (ag__.ld(input),), None, fscope)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
    return _call_unconverted(f, args, kwargs, options)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
    return f(*args)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py", line 2683, in convert_to_tensor_or_ragged_tensor
    return ops.convert_to_tensor_v2_with_dispatch(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/tensor_util.py", line 441, in make_tensor_proto
    raise ValueError("None values not supported.")
ValueError: in user code:

    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/utils.py", line 779, in None  *
        lambda arg: fn(arg, *args, **kargs)
    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/preprocessors.py", line 116, in tokenize_impl  *
        v = vocab.encode_tf(v)
    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 114, in encode_tf  *
        return self._encode_tf(s)
    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 413, in _encode_tf  *
        return self.tf_tokenizer.tokenize(s)
    File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow_text/python/ops/sentencepiece_tokenizer.py", line 133, in tokenize  *
        input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)

    ValueError: None values not supported.

  In call to configurable 'train' (<function train at 0x7f79d8db2280>)

I further tried the same with ByT5, and the same error occurs.
The following is the error that occurred using ByT5:

ValueError: in user code:

    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/utils.py", line 779, in None  *
        lambda arg: fn(arg, *args, **kargs)
    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/preprocessors.py", line 116, in tokenize_impl  *
        v = vocab.encode_tf(v)
    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 114, in encode_tf  *
        return self._encode_tf(s)
    File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 555, in _encode_tf  *
        tf_ids = tf.io.decode_raw(s, tf.uint8) + self._num_special_tokens

    ValueError: Tried to convert 'bytes' to a tensor and failed. Error: None values not supported.

Seg Fault after saving checkpoints

Hi,

I am getting a seg fault sometimes after the model has saved a checkpoint. It doesn't happen after every checkpoint, and it seems random which checkpoints it crashes after. I am not sure if it is related to issue #340.

For example, I am running prompt_tuning/scripts/sst2-demo-xxl.sh, and the output is below.

I0317 18:14:56.525280 140415323761728 utils.py:138] Saved Numpy Arrays for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/numpy_checkpoints/checkpoint_1104000
I0317 18:14:56.604028 140415323761728 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
I0317 18:14:56.614308 140622481194048 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
I0317 18:14:56.624289 140590966570048 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
I0317 18:14:56.653718 140272509271104 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
Fatal Python error: Segmentation fault


Thread 0x00007fdb1dc01700 (most recent call first):
  File "/home/dptam/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 664 in _sda_value
  File "/home/dptam/.local/lib/python3.8/site-packages/jax/_src/device_array.py", line 266 in __array__
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 447 in <lambda>
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/checkpoint_importer.py", line 84 in get
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57 in run
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 80 in _worker
  File "/usr/lib/python3.8/threading.py", line 870 in run
  File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap

Thread 0x00007f56809df700 (most recent call first):
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 78 in _worker
  File "/usr/lib/python3.8/threading.py", line 870 in run
  File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap

  Thread 0x00007f56c7aad700 (most recent call first):
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 78 in _worker
  File "/usr/lib/python3.8/threading.py", line 870 in run
  File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap
Thread 0x00007fdde29efc40 (most recent call first):
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 693 in _write_array
https://symbolize.stripped_domain/r/?trace=7fdde2e4203b,7fdde2e420bf,e,5ef27540f,e,26f7c5aff,f,b15f59df&map= 
E0317 18:14:57.770066  341059 process_state.cc:1062] RAW: Signal 11 raised at PC: 0x7fdde2e4203b while already in FailureSignalHandler!
E0317 18:14:57.770096  341059 process_state.cc:1065] RAW: tid: 341059 raised new signal
    @                0xf       1440  (unknown)
    @        0x25ed159b0  (unknown)  (unknown)
    @               0x10   76231216  (unknown)
    @        0x261cdc840  (unknown)  (unknown)
    @        0x2dfdd4780  (unknown)  (unknown)
    @        0x5f1f8a120  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7fdde301ffd3,7fddd98d57f9,7fdde2e420bf,7,e,25ed159af,f,261cdc83f,2dfdd477f,5f1f8a11f&map=7a511a57244151c993b16b37978e7ed7:7fddcaefd000-7fddd9c3fd50 
E0317 18:14:57.818885  341068 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
E0317 18:14:57.818900  341068 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
E0317 18:14:57.818919  341068 client.cc:221] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0317 18:14:57.818922  341068 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
E0317 18:14:57.818928  341068 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0317 18:14:57.818933  341068 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0317 18:14:57.818938  341068 coredump_hook.cc:550] RAW: Discarding core.
prompt_tuning/scripts/sst2-demo-xxl.sh: line 37: 337643 Segmentation fault      (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.CLASS_LABELS="['positive', 'negative']" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_212_000" --gin.USE_CACHED_TASKS="False" --gin.BATCH_SIZE="16" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --tfds_data_dir=${TFDS_DATA_DIR}
##### Command execution on worker 3 failed with return code 139. Continuing.
prompt_tuning/scripts/sst2-demo-xxl.sh: line 37: 334750 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.CLASS_LABELS="['positive', 'negative']" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_212_000" --gin.USE_CACHED_TASKS="False" --gin.BATCH_SIZE="16" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --tfds_data_dir=${TFDS_DATA_DIR}
##### Command execution on worker 1 failed with return code 134. Continuing.
prompt_tuning/scripts/sst2-demo-xxl.sh: line 37: 335504 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.CLASS_LABELS="['positive', 'negative']" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_212_000" --gin.USE_CACHED_TASKS="False" --gin.BATCH_SIZE="16" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --tfds_data_dir=${TFDS_DATA_DIR}
##### Command execution on worker 0 failed with return code 134. Continuing.

Thanks

Finetuning model aborts

I am fine-tuning T5X on the Winograd schema challenge according to the directions on the tutorial page. When I run my fine-tuning script, it produces the error below:

I0820 21:51:52.153287 140241781972032 gin_utils.py:65] network.Transformer.config = @network.T5Config()
I0820 21:51:52.154107 140241781972032 partitioning.py:474] `activation_partitioning_dims` = 1, `parameter_partitioning_dims` = 1
I0820 21:51:52.172375 140241781972032 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0820 21:51:52.172535 140241781972032 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0820 21:51:52.172620 140241781972032 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2022-08-20 21:51:52.172999: F external/org_tensorflow/tensorflow/core/tpu/tpu_library_init_fns.inc:100] TpuEmbeddingEngine_CollateMemory not available in this library.
Fatal Python error: Aborted

Thread 0x00007f8a83e09700 (most recent call first):
  File "/usr/lib/python3.8/threading.py", line 306 in wait
  File "/usr/lib/python3.8/threading.py", line 558 in wait
  File "/usr/lib/python3.8/threading.py", line 1252 in run
  File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap

Current thread 0x00007f8c9598dc40 (most recent call first):
  File "/home/l3atbc/.local/lib/python3.8/site-packages/jaxlib/xla_client.py", line 110 in make_tpu_client
  File "/home/l3atbc/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 190 in tpu_client_timer_callback
  File "/home/l3atbc/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 377 in _init_backend
  File "/home/l3atbc/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 326 in backends
  File "/home/l3atbc/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 401 in _get_backend_uncached
  File "/home/l3atbc/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 417 in get_backend
  File "/home/l3atbc/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 527 in process_index
  File "t5x/train.py", line 176 in train
  File "/home/l3atbc/.local/lib/python3.8/site-packages/gin/config.py", line 1582 in gin_wrapper
  File "t5x/train.py", line 741 in _main
  File "t5x/train.py", line 705 in main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 258 in _run_main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 312 in run
  File "/home/l3atbc/t5x/t5x/gin_utils.py", line 107 in run
  File "t5x/train.py", line 745 in <module>
https://symbolize.stripped_domain/r/?trace=7f8c95de000b,7f8c95de008f,7f8b6adef760,7f8b6f2346d2,7f8b6f234c9e,7f8b6cd1ff9c,7f8b6af40e19,7f8b6af25780,5f3988,903aff&map=
*** SIGABRT received by PID 7953 (TID 7953) on cpu 0 from PID 7953; stack trace: ***
PC: @     0x7f8c95de000b  (unknown)  raise
    @     0x7f8b4d3cbed3        992  (unknown)
    @     0x7f8c95de0090  187527472  (unknown)
    @     0x7f8b6adef761        448  tensorflow::tpu::(anonymous namespace)::SetTpuOpsStructFns()
    @     0x7f8b6f2346d3        128  tensorflow::tpu::InitializeTpuLibrary()
    @     0x7f8b6f234c9f        592  tensorflow::tpu::FindAndLoadTpuLibrary()
    @     0x7f8b6cd1ff9d       1008  xla::GetTpuClient()
    @     0x7f8b6af40e1a        208  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f8b6af25781        768  pybind11::cpp_function::dispatcher()
    @           0x5f3989  (unknown)  PyCFunction_Call
    @           0x903b00  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f8c95de000b,7f8b4d3cbed2,7f8c95de008f,7f8b6adef760,7f8b6f2346d2,7f8b6f234c9e,7f8b6cd1ff9c,7f8b6af40e19,7f8b6af25780,5f3988,903aff&map=4707200934d6baba849e70d87aad4e2c:7f8b390c1000-7f8b4d743eb0
E0820 21:51:52.231133    7953 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
E0820 21:51:52.231150    7953 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
E0820 21:51:52.231171    7953 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0820 21:51:52.231180    7953 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
E0820 21:51:52.231189    7953 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0820 21:51:52.231201    7953 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0820 21:51:52.231209    7953 coredump_hook.cc:550] RAW: Discarding core.
E0820 21:51:52.269138    7953 process_state.cc:765] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

Expected Output:
The model fine-tunes without aborting.

Exporting models

Are there any instructions available on how to export T5X models to other formats, for instance to PyTorch (or to another format that can then be converted to PyTorch)? I am trying to export a fine-tuned byT5 model.

Dataset seeking for restarting from a crashed run

I wrote some hacky support for HuggingFace datasets using seqio.FunctionDataSource, specifically for pretraining and finetuning pretrained models.

import functools

import seqio
import tensorflow as tf
from datasets import load_dataset  # HuggingFace datasets

dataset_name = 'NbAiLab/NCC'
dataset_params = {"path": dataset_name, "streaming": True}
dataset_shapes = {"train": 20830348, "validation": 473079}

def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
    dataset = load_dataset(**dataset_params)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:  # TODO: add for...loop over num_epochs
        for item in dataset[str(split)]:
            yield item[column]

def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
    # Wrap the generator in a tf.data.Dataset of scalar strings.
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_params=dataset_params),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    )

source = seqio.FunctionDataSource(
    dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
    splits=("train", "validation"),
    caching_permitted=False,
    num_input_examples=dataset_shapes,
)

But unfortunately, as I face constant random crashes during training (#366), I need a way to seek to the right dataset batch to properly continue training.

I see there's a continue_from_last_checkpoint variable in get_dataset(), but it seems it is not used for anything yet.

Is there a way to pass in the needed information to get_dataset_fn() so I can write the logic without using any hard-coded global variables?
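
A rough sketch of the kind of workaround I have in mind, assuming the generator yields examples in a fixed, deterministic order and that the per-host batch size is known (the helper and its arguments below are illustrative, not a t5x API):

import tensorflow as tf

def resume_from_step(ds: tf.data.Dataset, restored_step: int, per_host_batch_size: int):
    # Skip the examples this host already consumed before the crash.
    # Only valid if the generator ordering is reproducible across restarts.
    examples_seen = restored_step * per_host_batch_size
    return ds.skip(examples_seen)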

DuplicateFlagError when first running the code

Hello, I am trying to run the code on my GPU server. However, after I installed the package, I got an error.

This is the error message; could you help me with this?

Traceback (most recent call last):
  File "/home/guangtao/t5x/t5x/train.py", line 617, in <module>
    gin_utils.run(main)
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/t5x/gin_utils.py", line 103, in run
    app.run(
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/guangtao/t5x/t5x/train.py", line 596, in main
    _main(argv)
  File "/home/guangtao/t5x/t5x/train.py", line 609, in _main
    gin_utils.parse_gin_flags(
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/t5x/gin_utils.py", line 50, in parse_gin_flags
    import t5.data  # pylint:disable=unused-import,g-import-not-at-top
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/t5/__init__.py", line 19, in <module>
    import t5.models
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/t5/models/__init__.py", line 17, in <module>
    import t5.models.mesh_transformer
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/t5/models/mesh_transformer.py", line 22, in <module>
    import mesh_tensorflow.transformer.dataset as transformer_dataset
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/mesh_tensorflow/transformer/__init__.py", line 29, in <module>
    import mesh_tensorflow.transformer.utils
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/mesh_tensorflow/transformer/utils.py", line 50, in <module>
    tf.flags.DEFINE_multi_string("gin_file", None, "Path to a Gin file.")
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/absl/flags/_defines.py", line 657, in DEFINE_multi_string
    return DEFINE_multi(
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/absl/flags/_defines.py", line 620, in DEFINE_multi
    return DEFINE_flag(
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/absl/flags/_defines.py", line 140, in DEFINE_flag
    fv[flag.name] = flag
  File "/home/guangtao/.conda/envs/prompt-tuning/lib/python3.10/site-packages/absl/flags/_flagvalues.py", line 439, in __setitem__
    raise _exceptions.DuplicateFlagError.from_flag(name, self)
absl.flags._exceptions.DuplicateFlagError: The flag 'gin_file' is defined twice. First from t5x/train.py, Second from mesh_tensorflow.transformer.utils.  Description from first occurrence: Path to gin configuration file. Multiple paths may be passed and will be imported in the given order, with later configurations  overriding earlier ones.;
    repeat this option to specify a list of values

Provide an interactive class for mixing training and inference.

There are a lot of instructions here about how to train/fine-tune T5 itself directly on natural language tasks and datasets, but I'm interested in training another model that would take T5 embeddings of language as an input.

In order to do this, I would need to interleave T5 inference with my training script, which means I need a way to run T5 on a single batch, not just on an entire dataset. How should I go about this?
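
To frame the question, the rough shape of what I am after is something like the sketch below. It is purely illustrative: it assumes a t5x model object, restored parameters, and an already feature-converted batch are available; only model.predict_batch is an existing t5x call.

def infer_one_batch(model, params, batch):
    # BaseModel.predict_batch decodes a single batch without the dataset /
    # evaluator machinery used by t5x.eval; depending on the t5x version it
    # may also accept an optional rng argument.
    return model.predict_batch(params, batch)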

How does Partitioning Work?

I am working with Data Parallelism on t5x. When the train_step() function is pjit-ted by the following code stub in partitioning.py, is this similar to DistributedDataParallel in PyTorch or is it similar to DataParallel (with a single thread)?

pjitted = pjit(
        fn,
        in_axis_resources=in_axis_resources,
        out_axis_resources=out_axis_resources,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
        backend=self._backend)
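
For intuition, below is a standalone pjit sketch of pure data parallelism (not t5x code; the array shapes and names are made up, and the import paths moved in newer JAX releases to jax.sharding, with in/out_axis_resources renamed). With this layout each device keeps a full copy of the parameters and processes its own shard of the batch, with gradients combined by an all-reduce inserted by XLA, which is closer in spirit to DistributedDataParallel than to the single-threaded DataParallel.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))  # 1-D mesh: data parallelism only

def loss_fn(params, batch):
    preds = batch["x"] @ params            # toy linear "model"
    return jnp.mean((preds - batch["y"]) ** 2)

p_loss = pjit(
    loss_fn,
    # Parameters replicated on every device; batch sharded along its leading axis.
    in_axis_resources=(P(), {"x": P("data", None), "y": P("data", None)}),
    out_axis_resources=None,
)

with mesh:
    params = jnp.ones((8, 1))
    batch = {"x": jnp.ones((16, 8)), "y": jnp.zeros((16, 1))}
    print(p_loss(params, batch))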

byT5

I see that the byT5 model is implemented in T5X, but I cannot find any byT5 checkpoint to initiate fine-tuning from. Any chance you could make such a checkpoint available?

Error loading model from checkpoint on Apple M1

I am trying to load a longT5 model from a checkpoint and am getting the following error. Any help is much appreciated.


RuntimeError Traceback (most recent call last)
Input In [9], in <cell line: 1>()
----> 1 t5x_checkpoint = t5x.checkpoints.load_t5x_checkpoint(checkpoint_dir)

File ~/t5x/t5x/checkpoints.py:1594, in load_t5x_checkpoint(path, step, state_transformation_fns, remap, restore_dtype, lazy_parameters)
1592 if not lazy_parameters:
1593 future_state_dict = jax.tree_map(lambda x: x.get_async(), state_dict)
-> 1594 state_dict = _run_future_tree(future_state_dict)
1596 if restore_dtype is not None:
1597 state_dict['target'] = _cast(state_dict['target'], restore_dtype)

File ~/t5x/t5x/checkpoints.py:167, in _run_future_tree(future_tree)
165 # TODO(adarob): Use asyncio.run in py3.7+.
166 loop = asyncio.get_event_loop()
--> 167 leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
168 return jax.tree_unflatten(treedef, leaves)

File ~/opt/miniconda3/lib/python3.9/asyncio/base_events.py:623, in BaseEventLoop.run_until_complete(self, future)
612 """Run until the Future is done.
613
614 If the argument is a coroutine, it is wrapped in a Task.
(...)
620 Return the Future's result, or raise its exception.
621 """
622 self._check_closed()
--> 623 self._check_running()
625 new_task = not futures.isfuture(future)
626 future = tasks.ensure_future(future, loop=self)

File ~/opt/miniconda3/lib/python3.9/asyncio/base_events.py:583, in BaseEventLoop._check_running(self)
581 def _check_running(self):
582 if self.is_running():
--> 583 raise RuntimeError('This event loop is already running')
584 if events._get_running_loop() is not None:
585 raise RuntimeError(
586 'Cannot run the event loop while another loop is running')

RuntimeError: This event loop is already running
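If this is being run from a Jupyter/IPython notebook (the Input In [9] frame suggests so), the RuntimeError comes from load_t5x_checkpoint calling loop.run_until_complete inside the notebook's already-running event loop. One commonly used workaround, not specific to t5x, is nest_asyncio; a minimal sketch:

# Workaround sketch for notebook environments (requires `pip install nest_asyncio`).
import nest_asyncio
nest_asyncio.apply()  # allow run_until_complete inside the running notebook loop

import t5x
t5x_checkpoint = t5x.checkpoints.load_t5x_checkpoint(checkpoint_dir)  # as in the snippet above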

Incompatibility with jaxlib 0.3.7

The newest version of jax seems to require jaxlib v0.3.7, which breaks the trainer script:

$ ./run_pretrain.sh 
2022-04-16 23:34:03.151271: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 185, in _run_module_as_main
    mod_name, mod_spec, code = _get_module_details(mod_name, _Error)
  File "/usr/lib/python3.8/runpy.py", line 111, in _get_module_details
    __import__(pkg_name)
  File "/data/t5x/t5x/__init__.py", line 17, in <module>
    import t5x.adafactor
  File "/data/t5x/t5x/adafactor.py", line 63, in <module>
    from t5x import utils
  File "/data/t5x/t5x/utils.py", line 41, in <module>
    from t5x import checkpoints
  File "/data/t5x/t5x/checkpoints.py", line 51, in <module>
    from t5x import optimizers
  File "/data/t5x/t5x/optimizers.py", line 36, in <module>
    import optax
  File "/data/venvt5/lib/python3.8/site-packages/optax/__init__.py", line 17, in <module>
    from optax import experimental
  File "/data/venvt5/lib/python3.8/site-packages/optax/experimental/__init__.py", line 20, in <module>
    from optax._src.experimental.complex_valued import split_real_and_imaginary
  File "/data/venvt5/lib/python3.8/site-packages/optax/_src/experimental/complex_valued.py", line 32, in <module>
    import chex
  File "/data/venvt5/lib/python3.8/site-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/data/venvt5/lib/python3.8/site-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/data/venvt5/lib/python3.8/site-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/data/venvt5/lib/python3.8/site-packages/chex/_src/pytypes.py", line 40, in <module>
    CpuDevice = jax.lib.xla_extension.CpuDevice
AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'

Forcing the install of "jax[tpu]<0.3.7" works for now.

pip install -U "jax[tpu]<0.3.7" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Data pre-processing

Is there any data pre-processing that is called after batch = next(batch_iter) on line 490 of trainer.py? If so, where would it be?

How to log to wandb?

Hey there, I see that TensorBoard is the primary logging mechanism here. If I wanted to log all the metrics that TensorBoard is logging into wandb as well, how should I do that?
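
Not a t5x feature, but one general-purpose option is to let wandb mirror the TensorBoard summaries that the trainer already writes. A minimal sketch (project name is just a placeholder; whether it captures everything depends on which summary writer your t5x version uses):

import wandb

# Re-log TensorBoard summaries written in this process to wandb.
wandb.init(project="t5x-experiments", sync_tensorboard=True)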

`[OS error: Too many open files]` when saving checkpoint of model with 48 encoder layers and 48 decoder layers

Hello,
I tried to train a model with num_encoder_layers = 48 and num_decoder_layers = 48. I got the following error:

I0424 03:25:54.615840 140410741775424 checkpoints.py:600] Saving checkpoint for step 5000 to /mnt/disks/persist/t5_training_models/final_exps/pretraining/LIME/t5_48_bs32/checkpoint_5000.tmp-1650770754
Traceback (most recent call last):
  File "./t5x/train.py", line 635, in <module>
    gin_utils.run(main)
  File "/mnt/disks/persist/felix-t5x/t5x/gin_utils.py", line 103, in run
    app.run(
  File "/mnt/disks/persist/envs/t5/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/mnt/disks/persist/envs/t5/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "./t5x/train.py", line 614, in main
    _main(argv)
  File "./t5x/train.py", line 632, in _main
    train_using_gin()
  File "/mnt/disks/persist/envs/t5/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/mnt/disks/persist/envs/t5/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/mnt/disks/persist/envs/t5/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "./t5x/train.py", line 519, in train
    checkpointer.save(trainer.train_state,
  File "/mnt/disks/persist/felix-t5x/t5x/checkpoints.py", line 607, in save
    written_state_dict = self._write_state_to_tensorstore(
  File "/mnt/disks/persist/felix-t5x/t5x/checkpoints.py", line 720, in _write_state_to_tensorstore
    written_state_dict = _run_future_tree(future_written_state)
  File "/mnt/disks/persist/felix-t5x/t5x/checkpoints.py", line 160, in _run_future_tree
    leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
  File "/usr/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
    return future.result()
  File "/mnt/disks/persist/felix-t5x/t5x/checkpoints.py", line 699, in _write_array
    await t[param_info.local_chunk_info.slice].write(arr)
ValueError: Error writing local file "/mnt/disks/persist/t5_training_models/final_exps/pretraining/LIME/t5_48_bs32/checkpoint_5000.tmp-1650770754/target.decoder.layers_36.encoder_decoder_attention.value.kernel/0.0": Failed to open lock file: /mnt/disks/persist/t5_training_models/final_exps/pretraining/LIME/t5_48_bs32/checkpoint_5000.tmp-1650770754/target.decoder.layers_36.encoder_decoder_attention.value.kernel/0.0.__lock [OS error: Too many open files]
  In call to configurable 'train' (<function train at 0x7fb1e5e8bca0>)
/mnt/disks/persist/t5_training_models/final_exps/pretraining/LIME/t5_48_bs32

Thanks for any help.
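
For what it's worth, since each parameter array is written to its own TensorStore directory (as in the path in the error above), one mitigation independent of t5x is to raise the per-process open-file limit before training starts. A sketch, with 65536 as an arbitrary example value:

import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
new_soft = hard if hard != resource.RLIM_INFINITY else 65536
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))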

Pretraining from scratch

Thanks a lot for releasing this code. We have been trying to train a Norwegian T5-base from the sample Flax T5 (v1.1) implementation by the HuggingFace team. However, we have been struggling with some instabilities when fine-tuning these models, described in this issue, and were just looking into trying another implementation. Your implementation was released at a perfect time.

Could you provide some more info about how to use this library for pretraining a T5 for another language from scratch?

And in particular, how is the SeqIO SentencePiece tokenizer model implemented? Can it simply be replaced with a SentencePiece tokenizer trained by the SentencePieceUnigramTokenizer from HuggingFace Tokenizers?
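
For context, my current understanding (a sketch, not verified) is that a SeqIO task's output features reference the tokenizer as a seqio.SentencePieceVocabulary pointing at a SentencePiece .model file, roughly like this (paths are placeholders; whether a HuggingFace-trained tokenizer can be exported to a compatible .model protobuf is exactly my question):

import seqio

VOCAB = seqio.SentencePieceVocabulary(
    "gs://my-bucket/norwegian_spm.model", extra_ids=100)

OUTPUT_FEATURES = {
    "inputs": seqio.Feature(vocabulary=VOCAB, add_eos=True),
    "targets": seqio.Feature(vocabulary=VOCAB, add_eos=True),
}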

Adafactor HParamMap factor rule definitions missing?

Hi,

I am trying to run the sample code for scalable_t5 on a TPU VM (modified to a span corruption task). However, I am running into the following error message:

ValueError: A parameter with rank strictly higher than 2 must have an explicit factorization rule: decoder/layers/encoder_decoder_attention/key/kernel, (768, 12, 12, 64)

From what I can tell, it looks like some additional configuration should be provided in the Gin files for configuring Adafactor?

Thanks!

Fail to run pretrain from sample code

Following this sample code:
https://github.com/google-research/t5x/blob/main/docs/usage/pretrain.md

I got this error:
ValueError: The version of the dataset you are trying to use (c4/en/2.2.0) is too old for this version of TFDS so cannot be generated.Either sync to a previous version of TFDS to first prepare the data or use another version of the dataset. Available for download_and_prepare: ['3.0.1']

Would it be possible to update the dataset reference?
https://github.com/google-research/text-to-text-transfer-transformer/blob/259358fab5dd8bf523234110b89c22b751174694/t5/data/tasks.py#L47

FAIL to run "Example: English to German translation" from README file

Compute environment:

  • Google Cloud TPU VM (base image)
  • 1 x TPU V3-8
  • Default installation, according to README file

Storage:

  • I tried both with local disks (from TPU VM) and on Google Cloud Storage

Running this example, as described in the README file, I got the following error:

...
I0620 18:55:27.067202 139969230261312 train.py:545] Epoch 0 of 100
I0620 18:55:27.067452 139952425035520 logging_writer.py:48] [0] collection=train timing/compilation_seconds=80.777680
I0620 18:55:27.068523 139969230261312 train.py:551] BEGIN Train loop.
I0620 18:55:27.069384 139969230261312 train.py:556] Training for 500 steps.
I0620 18:55:27.071459 139969230261312 trainer.py:487] Training: step 0
I0620 18:55:37.198480 139969230261312 trainer.py:487] Training: step 46
I0620 18:55:47.225553 139969230261312 trainer.py:487] Training: step 90
I0620 18:55:57.255041 139969230261312 trainer.py:487] Training: step 134
I0620 18:56:07.280249 139969230261312 trainer.py:487] Training: step 178
I0620 18:56:17.305412 139969230261312 trainer.py:487] Training: step 222
I0620 18:56:27.334421 139969230261312 trainer.py:487] Training: step 266
I0620 18:56:37.359880 139969230261312 trainer.py:487] Training: step 310
I0620 18:56:47.386486 139969230261312 trainer.py:487] Training: step 354
I0620 18:56:57.412878 139969230261312 trainer.py:487] Training: step 398
I0620 18:57:07.443236 139969230261312 trainer.py:487] Training: step 442
I0620 18:57:17.466779 139969230261312 trainer.py:487] Training: step 486
I0620 18:57:20.668031 139969230261312 train.py:581] END Train loop.
I0620 18:57:20.668265 139969230261312 train.py:454] Compiling training eval loop.
I0620 18:57:24.448034 139952425035520 logging_writer.py:48] [500] collection=train accuracy=0.213222, cross_ent_loss=143349.328000, cross_ent_loss_per_all_target_tokens=4.374674, learning_rate=0.010000, learning_rate/current=0.009999999776482582, loss=143754.976000, loss_per_all_target_tokens=4.387054, loss_per_nonpadding_target_token=5.047726, nonpadding_fraction=0.869115, timing/seconds=117.008831, timing/seqs=64000, timing/seqs_per_second=546.967261, timing/seqs_per_second_per_core=68.370908, timing/steps_per_second=4.273182, timing/target_tokens_per_second=140023.618929, timing/target_tokens_per_second_per_core=17502.952366, z_loss=405.720469, z_loss_per_all_target_tokens=0.012382
I0620 18:57:36.698853 139951174522624 logging_writer.py:48] [500] collection=training_eval/wmt_t2t_ende_v003 timing/compilation_seconds=15.152848
I0620 18:57:36.700013 139969230261312 train.py:459] Computing training evaluation metrics.
I0620 18:57:37.040851 139969230261312 trainer.py:529] Evaluating: wmt_t2t_ende_v003.
I0620 18:57:39.755579 139951174522624 logging_writer.py:48] [500] collection=training_eval/wmt_t2t_ende_v003 accuracy=0.298345, cross_ent_loss=119181.912500, cross_ent_loss_per_all_target_tokens=3.637143, loss=119662.062500, loss_per_all_target_tokens=3.651796, loss_per_nonpadding_target_token=4.061035, nonpadding_fraction=0.899228, timing/seconds=2.427278, timing/seqs=2560, timing/seqs_per_second=1054.679339, timing/seqs_per_second_per_core=131.834917, timing/steps_per_second=8.239682, timing/target_tokens_per_second=269997.910706, timing/target_tokens_per_second_per_core=33749.738838, z_loss=480.150391, z_loss_per_all_target_tokens=0.014653
I0620 18:57:39.886704 139969230261312 train.py:473] Running inference evaluation.
Traceback (most recent call last):
  File "/home/renatoleite/workspace/t5x/t5x/train.py", line 695, in <module>
    gin_utils.run(main)
  File "/home/renatoleite/workspace/t5x/t5x/gin_utils.py", line 107, in run
    app.run(
  File "/home/renatoleite/.local/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/renatoleite/.local/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/renatoleite/workspace/t5x/t5x/train.py", line 673, in main
    _main(argv)
  File "/home/renatoleite/workspace/t5x/t5x/train.py", line 693, in _main
    train_using_gin()
  File "/home/renatoleite/.local/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/renatoleite/.local/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/renatoleite/.local/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/renatoleite/workspace/t5x/t5x/train.py", line 616, in train
    _run_inference_eval()
  File "/home/renatoleite/workspace/t5x/t5x/train.py", line 475, in _run_inference_eval
    all_metrics, _, _ = evaluator.evaluate(
TypeError: evaluate() got an unexpected keyword argument 'predict_with_aux_fn'
  In call to configurable 'train' (<function train at 0x7f4b7ae3c8b0>)

I checked the .gin file and the variables seem to be set correctly.
Any idea what could be generating this error?

How to use a different optimizer?

There is an excellent example of how to use a different optimizer for a model trained from scratch, but I am getting an error when I try to use a different optimizer for fine-tuning.

RuntimeError: UNIMPLEMENTED: Requested AllReduceStart not implemented on GPU

I am trying to run one of the fine-tuning examples on a machine with 2 GPUs and getting the following error:

RuntimeError: UNIMPLEMENTED: Requested AllReduceStart not implemented on GPU; replica_count: 1; partition_count: 2, group_mode: kCrossReplicaAndPartition, operand_count: 49; NCCL support: 1; first operand array element-type: BF16
In call to configurable 'train' (<function train at 0x7fed3c31e3a0>)

full error trace below:


RuntimeError Traceback (most recent call last)
Input In [19], in <cell line: 24>()
17 gin_utils.parse_gin_flags(
18 # User-provided gin paths take precedence if relative paths conflict.
19 FLAGS.gin_search_paths,# + _DEFAULT_GIN_SEARCH_PATHS,
20 FLAGS.gin_file,
21 FLAGS.gin_bindings)
22 train_using_gin()
---> 24 gin_utils.run(main_train)

File /anaconda/envs/py39/lib/python3.9/site-packages/t5x/gin_utils.py:105, in run(main)
103 def run(main):
104 """Wrapper for app.run that rewrites gin args before parsing."""
--> 105 app.run(
106 main,
107 flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))

File /anaconda/envs/py39/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
310 callback()
311 try:
--> 312 _run_main(main, args)
313 except UsageError as error:
314 usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)

File /anaconda/envs/py39/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
256 sys.exit(retval)
257 else:
--> 258 sys.exit(main(argv))

Input In [18], in main_train(argv)
1 def main_train(argv: Sequence[str]):
2 """Wrapper for pdb post mortems."""
----> 3 _main(argv)

Input In [19], in _main(argv)
15 train_using_gin = gin.configurable(train)
17 gin_utils.parse_gin_flags(
18 # User-provided gin paths take precedence if relative paths conflict.
19 FLAGS.gin_search_paths,# + _DEFAULT_GIN_SEARCH_PATHS,
20 FLAGS.gin_file,
21 FLAGS.gin_bindings)
---> 22 train_using_gin()

File /anaconda/envs/py39/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
1604 err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1605 utils.augment_exception_message_and_reraise(e, err_str)

File /anaconda/envs/py39/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
39 proxy = ExceptionProxy()
40 ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41 raise proxy.with_traceback(exception.__traceback__) from None

File /anaconda/envs/py39/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1579 new_kwargs.update(kwargs)
1581 try:
-> 1582 return fn(*new_args, **new_kwargs)
1583 except Exception as e: # pylint: disable=broad-except
1584 err_str = ''

Input In [10], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
422 logging.info('Compiling train loop.')
423 logging.flush()
--> 424 trainer.compile_train(first_batch)
426 # Main Loop over "epochs".
427 for epoch in range(first_epoch, num_epochs):

File /anaconda/envs/py39/lib/python3.9/site-packages/t5x/trainer.py:545, in BaseTrainer.compile_train(self, batch)
532 """Pre-compiles train step (if not yet compiled).
533
534 Not required.
(...)
542 shapes and dtypes.
543 """
544 tick = time.time()
--> 545 self._compiled_train_step = self._partitioner.compile(
546 self._partitioned_train_step, self.train_state, batch)
547 tock = time.time()
548 self.train_metrics_manager.write_scalar("timing/compilation_seconds",
549 tock - tick, self.train_state.step)

File /anaconda/envs/py39/lib/python3.9/site-packages/t5x/partitioning.py:795, in BasePjitPartitioner.compile(self, partitioned_fn, *args)
793 def compile(self, partitioned_fn: PjittedFnWithContext,
794 *args) -> CompiledPartitionedCallable:
--> 795 return partitioned_fn.lower(*args).compile()

File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/stages.py:221, in Lowered.compile(self)
220 def compile(self) -> Compiled:
--> 221 return Compiled(self._lowering.compile(), self.args_info,
222 self.out_tree, no_kwargs=self._no_kwargs)

File /anaconda/envs/py39/lib/python3.9/site-packages/jax/interpreters/pxla.py:2346, in MeshComputation.compile(self, _allow_propagation_to_outputs, _allow_compile_replicated)
2342 def compile(self,
2343 _allow_propagation_to_outputs : bool = False,
2344 _allow_compile_replicated : bool = True) -> 'MeshExecutable':
2345 if self._executable is None:
-> 2346 self._executable = MeshExecutable.from_hlo(
2347 self._name, self._hlo, **self.compile_args,
2348 _allow_propagation_to_outputs=_allow_propagation_to_outputs,
2349 _allow_compile_replicated=_allow_compile_replicated) # type: ignore
2350 return self._executable

File /anaconda/envs/py39/lib/python3.9/site-packages/jax/interpreters/pxla.py:2456, in MeshExecutable.from_hlo(name, computation, mesh, global_in_avals, global_out_avals, in_axes, out_axes, spmd_lowering, tuple_args, in_is_global, auto_spmd_lowering, _allow_propagation_to_outputs, _allow_compile_replicated)
2453 else:
2454 with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
2455 "in {elapsed_time} sec"):
-> 2456 xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
2458 if auto_spmd_lowering:
2459 in_axes, out_axes = _get_array_mapping_from_executable(xla_executable, mesh)

File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/dispatch.py:664, in compile_or_get_cached(backend, computation, compile_options)
661 ir_str = (computation if isinstance(computation, str)
662 else computation.as_hlo_text())
663 _dump_ir_to_file(module_name, ir_str)
--> 664 return backend_compile(backend, computation, compile_options)

File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/profiler.py:206, in annotate_function.<locals>.wrapper(*args, **kwargs)
203 @wraps(func)
204 def wrapper(*args, **kwargs):
205 with TraceAnnotation(name, **decorator_kwargs):
--> 206 return func(*args, **kwargs)
207 return wrapper

File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/dispatch.py:618, in backend_compile(backend, built_c, options)
614 @profiler.annotate_function
615 def backend_compile(backend, built_c, options):
616 # we use a separate function call to ensure that XLA compilation appears
617 # separately in Python profiling results
--> 618 return backend.compile(built_c, compile_options=options)

RuntimeError: UNIMPLEMENTED: Requested AllReduceStart not implemented on GPU; replica_count: 1; partition_count: 2, group_mode: kCrossReplicaAndPartition, operand_count: 49; NCCL support: 1; first operand array element-type: BF16
In call to configurable 'train' (<function train at 0x7fed3c31e3a0>)

Convert script not running

It is not currently possible to run the MT->T5X conversion script directly like this:

python -m t5x.scripts.convert_tf_checkpoint \
 --gin_file=t5x/examples/t5/t5_1_0/small.gin\
 --gin.convert_checkpoint.model=%MODEL\
 --gin.convert_checkpoint.tf_checkpoint_path=\
\"gs://t5-data/pretrained_models/small/model.ckpt-1000000\"\
 --gin.convert_checkpoint.output_dir=\"/tmp/t5x_checkpoints/t5_small\"

Running it gives the following error:

 raise ValueError(fmt.format(binding_key))                                                                                                                                                    
ValueError: DROPOUT_RATE/macro.value set to `%gin.REQUIRED` but not subsequently overridden.

Setting DROPOUT_RATE manually, for instance at the end of the small.gin file, is a workaround. I am not sure what a clean way of fixing the issue would be.

partitioning issues during inference on v3-32

Hi,

I was running inference with prompt-tuning, which I think calls this codebase, and I ran into a partitioning issue when doing inference on a v3-32: TypeError: 'ShapeDtypeStruct' object is not iterable. Training works fine on a v3-32, and training and inference work fine on a v3-8.

Here is the traceback.

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 234, in <module>
    gin_utils.run(main)
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/gin_utils.py", line 103, in run
    app.run(
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 213, in main
    _main(argv)
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 231, in _main
    evaluate_using_gin()
  File "/home/dptam/.local/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/dptam/.local/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/dptam/.local/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 127, in evaluate
    train_state_initializer = utils.TrainStateInitializer(
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/utils.py", line 365, in __init__
    self.train_state_axes = partitioner.get_mesh_axes(
  File "/home/dptam/.local/lib/python3.8/site-packages/t5x/partitioning.py", line 826, in get_mesh_axes
    mesh_axes_dict = jax.tree_map(flax_partitioning.logical_to_mesh_axes,
  File "/home/dptam/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/dptam/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/dptam/.local/lib/python3.8/site-packages/flax/linen/partitioning.py", line 154, in logical_to_mesh_axes
    axis_name_counts = collections.Counter(array_dim_names)
  File "/usr/lib/python3.8/collections/__init__.py", line 552, in __init__
    self.update(iterable, **kwds)
  File "/usr/lib/python3.8/collections/__init__.py", line 637, in update
    _count_elements(self, iterable)
TypeError: 'ShapeDtypeStruct' object is not iterable
  In call to configurable 'evaluate' (<function evaluate at 0x7f784d161700>)
Rewritten gin arg: --gin_bindings=MIXTURE_OR_TASK_NAME = 'glue_rte_32_shot_32_seed'
Rewritten gin arg: --gin_bindings=MIXTURE_OR_TASK_MODULE = 'prompt_tuning.data.few_glue'
Rewritten gin arg: --gin_bindings=TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 8}
Rewritten gin arg: --gin_bindings=CHECKPOINT_PATH = 'gs://nicl/pretrained_models/t5x_checkpoints/t0_3b/checkpoint_1112000'
Rewritten gin arg: --gin_bindings=EVAL_OUTPUT_DIR = 'gs://nicl/checkpoint_models/rte/32_shot/32_seed/prompt-tuning/t0-3b/eval'
Rewritten gin arg: --gin_bindings=utils.DatasetConfig.split = 'validation'
Rewritten gin arg: --gin_bindings=utils.DatasetConfig.batch_size = 128
Rewritten gin arg: --gin_bindings=USE_CACHED_TASKS = False
Rewritten gin arg: --gin_bindings=partitioning.ModelBasedPjitPartitioner.model_parallel_submesh = (4, 4, 1, 2)
Rewritten gin arg: --gin_bindings=PROMPT_FILE = 'gs://nicl/checkpoint_models/rte/32_shot/32_seed/prompt-tuning/t0-3b/numpy_checkpoints/checkpoint_1112300/encoder.prompt.prompt.prompt'
##### Command execution on worker 0 failed with return code 1. Continuing.
##### Command execution on worker 3 failed with return code 1. Continuing.
##### Command execution on worker 1 failed with return code 1. Continuing.
##### Command execution on worker 2 failed with return code 1. Continuing.

TypeError: get_dataset() got an unexpected keyword argument 'trim_output_features'

Hey, the current release has issues when pretraining or finetuning. I actually had a custom task to pretrain; upon launching the training I get the error TypeError: get_dataset() got an unexpected keyword argument 'trim_output_features'.

Hence I tried the simple WMT finetuning task from the documentation; even that task faces the same error.

The following is the detailed error:

File "/home/stephen/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/stephen/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/stephen/Desktop/t5_latest/t5x/t5x/train.py", line 746, in <module>
    gin_utils.run(main)
  File "/home/stephen/Desktop/t5_latest/t5x/t5x/gin_utils.py", line 107, in run
    app.run(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/stephen/Desktop/t5_latest/t5x/t5x/train.py", line 709, in main
    _main(argv)
  File "/home/stephen/Desktop/t5_latest/t5x/t5x/train.py", line 744, in _main
    train_using_gin()
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/stephen/Desktop/t5_latest/t5x/t5x/train.py", line 247, in train
    train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
  File "/home/stephen/Desktop/t5_latest/t5x/t5x/utils.py", line 1366, in get_dataset
    return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
  File "/home/stephen/Desktop/t5_latest/t5x/t5x/utils.py", line 1386, in get_dataset_inner
    ds = seqio.get_dataset(
TypeError: get_dataset() got an unexpected keyword argument 'trim_output_features'
  In call to configurable 'train' (<function train at 0x7ff9e1d67af0>)

Further, I tried removing the trim_output_features argument, i.e. removing trim_output_features=cfg.trim_output_features from the seqio.get_dataset() call in the get_dataset_inner() function.

Then I get the error: TypeError: Can't instantiate abstract class LegacyCheckpointer with abstract methods async_restore, async_save In call to configurable 'train' (<function train at 0x7f3f62ce8af0>)

Incorrect checkpoint path

  File "/home/torinaki/src/product-description-generation/t5x/t5x/utils.py", line 472, in from_checkpoints
tensorstore/internal/oauth2/google_auth_provider.cc:173: Using credentials at bigquery-key.json
    yield _restore_path(path, restore_cfg)
  File "/home/torinaki/src/product-description-generation/t5x/t5x/utils.py", line 461, in _restore_path
    return restore_checkpointer.restore(
  File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 811, in restore
    state_dict = self._read_state_from_tensorstore(
  File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 860, in _read_state_from_tensorstore
    state_dict = _run_future_tree(future_state_dict)
  File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 160, in _run_future_tree
    leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
  File "/home/torinaki/.pyenv/versions/3.8.9/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
    return future.result()
  File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoint_importer.py", line 115, in _get_and_cast
    arr = await self._get_fn()  # pytype: disable=bad-return-type
  File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 1190, in _read_ts
tensorstore/internal/oauth2/google_auth_provider.cc:189: Using ServiceAccount AuthProvider
    t = await ts.open(tmp_ts_spec_dict, open=True)
ValueError: Error opening "zarr" driver: Metadata at "gs://t5x-dummy-bucket/gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000/target.decoder.layers_0.self_attention.value.kernel/.zarray" does not exist
  In call to configurable 'train' (<function train at 0x7f8b444b7c10>)

It seems that the problem is in: https://github.com/google-research/t5x/blob/main/t5x/checkpoints.py#L219

Getting Logits of Tokens

In inference, I want to get the logit over vocabulary for the first generated token. How can I do that?

Import error with flax.optim

I'm trying to train T5X on the Winograd schema challenge. I run the training script as instructed in the tutorial but receive this error.

ImportError: cannot import name 'optim' from 'flax' (/home/[name]/.local/lib/python3.8/site-packages/flax/__init__.py)

Multi-node GPU training

Can you provide some documentation on using T5X models on GPUs?
I can see some examples in the original T5 repository but not here.
Also, does T5X support training on a multi-node GPU cluster?

XLA fails to compile model

I'm trying to follow instructions to fine-tune the translation model: https://github.com/google-research/t5x#fine-tuning
Unfortunately, I'm getting an error:

RuntimeError: UNIMPLEMENTED: While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %bitcast-convert.4 = u64[2]{0} bitcast-convert(u32[2,2]{1,0} %reshape.3), metadata={op_type="rng_bit_generator" op_name="jit(rng_bit_generator)/rng_bit_generator[\n  algorithm=RNG_DEFAULT\n  dtype=uint32\n  shape=(20, 4)\n]" source_file="/home/src/t5x/t5x/train.py" source_line=215}

Installing dependencies in Colab leads to excessive tensorflow & tf-nightly installation

The following chunk leads to excessive installation time and disk space consumption in Google Colab (more than usual). It appears to be the cause of issue 46 from Magenta's MT3 repo.

!git clone --branch=main https://github.com/google-research/t5x
%cd t5x

!python3 -m pip install -e '.[tpu]' -f \
  https://storage.googleapis.com/jax-releases/libtpu_releases.html

Output (manually aborted after a few hours):

Cloning into 't5x'...
remote: Enumerating objects: 2584, done.
remote: Counting objects: 100% (316/316), done.
remote: Compressing objects: 100% (147/147), done.
remote: Total 2584 (delta 193), reused 224 (delta 169), pack-reused 2268
Receiving objects: 100% (2584/2584), 6.69 MiB | 15.68 MiB/s, done.
Resolving deltas: 100% (1785/1785), done.
/content/t5x
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Obtaining file:///content/t5x
Collecting clu@ git+https://github.com/google/CommonLoopUtils#egg=clu
  Cloning https://github.com/google/CommonLoopUtils to /tmp/pip-install-u4q2bj3y/clu_2518408ae45f4c0f94de081d5bd84d2e
  Running command git clone -q https://github.com/google/CommonLoopUtils /tmp/pip-install-u4q2bj3y/clu_2518408ae45f4c0f94de081d5bd84d2e
Collecting flax@ git+https://github.com/google/flax#egg=flax
  Cloning https://github.com/google/flax to /tmp/pip-install-u4q2bj3y/flax_452365d37fce4708a516ee532109db2a
  Running command git clone -q https://github.com/google/flax /tmp/pip-install-u4q2bj3y/flax_452365d37fce4708a516ee532109db2a
Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from t5x==0.0.0) (1.0.0)
Requirement already satisfied: cached_property in /usr/local/lib/python3.7/dist-packages (from t5x==0.0.0) (1.5.2)
Requirement already satisfied: gin-config in /usr/local/lib/python3.7/dist-packages (from t5x==0.0.0) (0.5.0)
Requirement already satisfied: jax>=0.2.27 in /usr/local/lib/python3.7/dist-packages (from t5x==0.0.0) (0.3.8)
Requirement already satisfied: jaxlib>=0.1.76 in /usr/local/lib/python3.7/dist-packages (from t5x==0.0.0) (0.3.7+cuda11.cudnn805)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from t5x==0.0.0) (1.21.6)
Collecting orbax
  Downloading orbax-0.0.1-py3-none-any.whl (34 kB)
Collecting seqio-nightly
  Downloading seqio_nightly-0.0.7.dev20220524-py3-none-any.whl (301 kB)
     |████████████████████████████████| 301 kB 11.8 MB/s 
Collecting t5
  Downloading t5-0.9.3-py3-none-any.whl (153 kB)
     |████████████████████████████████| 153 kB 47.5 MB/s 
Requirement already satisfied: tensorflow in /usr/local/lib/python3.7/dist-packages (from t5x==0.0.0) (2.8.0+zzzcolab20220506162203)
Collecting tensorstore>=0.1.20
  Downloading tensorstore-0.1.20-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.0 MB)
     |████████████████████████████████| 9.0 MB 21.3 MB/s 
Collecting etils[epath]
  Downloading etils-0.5.1-py3-none-any.whl (87 kB)
     |████████████████████████████████| 87 kB 5.5 MB/s 
Collecting ml_collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
     |████████████████████████████████| 77 kB 5.6 MB/s 
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (21.3)
Requirement already satisfied: tensorflow_datasets in /usr/local/lib/python3.7/dist-packages (from clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (4.0.1)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (3.2.2)
Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (1.0.3)
Collecting optax
  Downloading optax-0.1.2-py3-none-any.whl (140 kB)
     |████████████████████████████████| 140 kB 36.8 MB/s 
Collecting rich~=11.1.0
  Downloading rich-11.1.0-py3-none-any.whl (216 kB)
     |████████████████████████████████| 216 kB 43.2 MB/s 
Requirement already satisfied: typing_extensions>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (4.2.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.27->t5x==0.0.0) (3.3.0)
Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.27->t5x==0.0.0) (1.4.1)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.27->t5x==0.0.0) (2.23.0)
Collecting libtpu-nightly==0.1.dev20220415
  Downloading https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220415-py3-none-any.whl (184.0 MB)
     |████████████████████████████████| 184.0 MB 21 kB/s 
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.76->t5x==0.0.0) (2.0)
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
     |████████████████████████████████| 51 kB 5.6 MB/s 
Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.7/dist-packages (from rich~=11.1.0->flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (2.6.1)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->t5x==0.0.0) (1.15.0)
Collecting tf-nightly
  Downloading tf_nightly-2.10.0.dev20220524-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (503.4 MB)
     |████████████████████████████████| 503.4 MB 29 kB/s 
Requirement already satisfied: importlib_resources in /usr/local/lib/python3.7/dist-packages (from etils[epath]->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (5.7.1)
Requirement already satisfied: zipp in /usr/local/lib/python3.7/dist-packages (from etils[epath]->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (3.8.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (1.4.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (3.0.9)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (0.11.0)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (2.8.2)
Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from ml_collections->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (3.13)
Requirement already satisfied: contextlib2 in /usr/local/lib/python3.7/dist-packages (from ml_collections->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (0.5.5)
Collecting chex>=0.0.4
  Downloading chex-0.1.3-py3-none-any.whl (72 kB)
     |████████████████████████████████| 72 kB 509 kB/s 
Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (0.1.7)
Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax@ git+https://github.com/google/flax#egg=flax->t5x==0.0.0) (0.11.2)
Collecting dataclasses
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->jax>=0.2.27->t5x==0.0.0) (2.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->jax>=0.2.27->t5x==0.0.0) (2022.5.18.1)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->jax>=0.2.27->t5x==0.0.0) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->jax>=0.2.27->t5x==0.0.0) (3.0.4)
Collecting tfds-nightly
  Downloading tfds_nightly-4.5.2.dev202205240044-py3-none-any.whl (4.3 MB)
     |████████████████████████████████| 4.3 MB 29.0 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
     |████████████████████████████████| 1.2 MB 51.5 MB/s 
Collecting tensorflow-text
  Downloading tensorflow_text-2.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)
     |████████████████████████████████| 4.6 MB 46.0 MB/s 
Collecting seqio
  Downloading seqio-0.0.7-py3-none-any.whl (286 kB)
     |████████████████████████████████| 286 kB 56.1 MB/s 
Collecting mesh-tensorflow[transformer]>=0.1.13
  Downloading mesh_tensorflow-0.1.21-py3-none-any.whl (385 kB)
     |████████████████████████████████| 385 kB 59.7 MB/s 
Requirement already satisfied: editdistance in /usr/local/lib/python3.7/dist-packages (from t5->t5x==0.0.0) (0.5.3)
Requirement already satisfied: babel in /usr/local/lib/python3.7/dist-packages (from t5->t5x==0.0.0) (2.10.1)
Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from t5->t5x==0.0.0) (1.11.0+cu113)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from t5->t5x==0.0.0) (1.0.2)
Collecting transformers>=2.7.0
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
     |████████████████████████████████| 4.2 MB 43.5 MB/s 
Collecting sacrebleu
  Downloading sacrebleu-2.1.0-py3-none-any.whl (92 kB)
     |████████████████████████████████| 92 kB 9.4 MB/s 
Collecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (from t5->t5x==0.0.0) (3.2.5)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from t5->t5x==0.0.0) (1.3.5)
Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from mesh-tensorflow[transformer]>=0.1.13->t5->t5x==0.0.0) (0.16.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers>=2.7.0->t5->t5x==0.0.0) (3.7.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers>=2.7.0->t5->t5x==0.0.0) (2019.12.20)
Collecting PyYAML
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
     |████████████████████████████████| 596 kB 41.4 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)
     |████████████████████████████████| 86 kB 4.2 MB/s 
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers>=2.7.0->t5->t5x==0.0.0) (4.11.3)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
     |████████████████████████████████| 6.6 MB 42.1 MB/s 
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers>=2.7.0->t5->t5x==0.0.0) (4.64.0)
Requirement already satisfied: pytz>=2015.7 in /usr/local/lib/python3.7/dist-packages (from babel->t5->t5x==0.0.0) (2022.1)
Collecting portalocker
  Downloading portalocker-2.4.0-py2.py3-none-any.whl (16 kB)
Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from sacrebleu->t5->t5x==0.0.0) (0.8.9)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->t5->t5x==0.0.0) (3.1.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->t5->t5x==0.0.0) (1.1.0)
Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (1.14.1)
Collecting tf-estimator-nightly==2.8.0.dev2021122109
  Downloading tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)
     |████████████████████████████████| 462 kB 57.9 MB/s 
Requirement already satisfied: gast>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (0.5.3)
Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (3.17.3)
Requirement already satisfied: tensorboard<2.9,>=2.8 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (2.8.0)
Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (1.6.3)
Requirement already satisfied: libclang>=9.0.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (14.0.1)
Requirement already satisfied: keras<2.9,>=2.8.0rc0 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (2.8.0)
Requirement already satisfied: keras-preprocessing>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (1.1.2)
Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (0.26.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (57.4.0)
Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (1.46.1)
Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (3.1.0)
Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (0.2.0)
Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow->t5x==0.0.0) (1.1.0)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.7/dist-packages (from astunparse>=1.6.0->tensorflow->t5x==0.0.0) (0.37.1)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (0.6.1)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (1.35.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (3.3.7)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (1.8.1)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (0.4.6)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (1.0.1)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (4.2.4)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (4.8)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (0.2.8)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (1.3.1)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.9,>=2.8->tensorflow->t5x==0.0.0) (3.2.0)
Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.7/dist-packages (from tensorflow_datasets->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (1.8.0)
Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow_datasets->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (21.4.0)
Requirement already satisfied: promise in /usr/local/lib/python3.7/dist-packages (from tensorflow_datasets->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (2.3)
Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from tensorflow_datasets->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (0.3.5.1)
Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-metadata->tensorflow_datasets->clu@ git+https://github.com/google/CommonLoopUtils#egg=clu->t5x==0.0.0) (1.56.1)
Requirement already satisfied: tensorflow-hub>=0.8.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-text->seqio-nightly->t5x==0.0.0) (0.12.0)
Collecting tensorflow
  Downloading tensorflow-2.9.1, then tensorflow-2.9.0 (511.7 MB each)
INFO: pip is looking at multiple versions of tensorflow-text to determine which version is compatible with other requirements. This could take a while.
Collecting tensorflow-text
  Downloading tensorflow_text-2.8.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.9 MB)
Collecting gast>=0.2.1
  Downloading gast-0.4.0-py3-none-any.whl (9.8 kB)
Collecting flatbuffers<3.0,>=1.12
  Downloading flatbuffers-1.12-py2.py3-none-any.whl (15 kB)
Collecting tf-nightly
  Downloading tf_nightly wheels from 2.10.0.dev20220521 back to 2.9.0.dev20220401, one build at a time (roughly 500 MB each)
Collecting keras-nightly~=2.9.0.dev
  Downloading keras_nightly-2.9.0.dev2022033107-py2.py3-none-any.whl (1.6 MB)
Collecting tf-nightly
  Downloading tf_nightly wheels from 2.9.0.dev20220329 back to 2.9.0.dev20220201, one build at a time (roughly 500 MB each)
INFO: pip is looking at multiple versions of tensorflow-hub to determine which version is compatible with other requirements. This could take a while.
Collecting tensorflow-hub>=0.8.0
  Downloading tensorflow_hub 0.12.0, 0.11.0, 0.10.0, 0.9.0 and 0.8.0
Collecting tensorflow-text
  Downloading tensorflow_text 2.8.1 and 2.7.3
Collecting tensorflow
  Downloading tensorflow 2.7.3, 2.7.2, 2.7.1 and 2.7.0 (both the PyPI and colab-wheels builds) and tensorflow_estimator-2.7.0
Collecting tensorflow-text
  Downloading tensorflow_text 2.7.0 and 2.6.0
Collecting tensorflow
  Downloading tensorflow 2.6.5, 2.6.4, 2.6.3, 2.6.2, 2.6.1 and 2.6.0 (both the PyPI and colab-wheels builds) and tensorflow_estimator-2.6.0
Collecting tensorflow-text
  Downloading tensorflow_text-2.5.0
Collecting tensorflow
  Downloading tensorflow 2.5.3, 2.5.2, 2.5.1 and 2.5.0, tensorflow_estimator-2.5.0 and grpcio-1.34.1
Collecting tensorflow-text
  Downloading tensorflow_text-2.4.3
Collecting tensorflow
  Downloading tensorflow 2.4.4, 2.4.3, 2.4.2, 2.4.1 and 2.4.0
INFO: pip is looking at multiple versions of tensorflow-text to determine which version is compatible with other requirements. This could take a while.
Collecting tensorflow-text
  Downloading tensorflow_text 2.4.2, 2.4.1 and 2.3.0
Collecting tensorflow
  Downloading tensorflow 2.3.4, 2.3.3, 2.3.2, 2.3.1 and 2.3.0
Collecting tensorflow-text
  Downloading tensorflow_text-2.2.1
Collecting tensorflow
  Downloading tensorflow 2.2.3, 2.2.2, 2.2.1 and 2.2.0, tensorflow_estimator-2.2.0 and gast-0.3.3
Collecting tensorflow-text
  Downloading tensorflow_text 2.2.0 and 2.1.1
INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. If you want to abort this run, you can press Ctrl + C to do so. To improve how pip performs, tell us what happened here: https://pip.pypa.io/surveys/backtracking
Collecting tensorflow
  Downloading tensorflow 2.1.4, 2.1.3, 2.1.2, 2.1.1 and 2.1.0, keras_applications-1.0.8 and gast-0.2.2
Collecting tensorflow-text
  Downloading tensorflow_text-2.0.1
Collecting tensorflow
  Downloading tensorflow 2.0.4, 2.0.3, 2.0.2, 2.0.1 and 2.0.0, tensorflow_estimator-2.0.1 and h5py-2.10.0
Collecting tensorflow-text
  Downloading tensorflow_text 2.0.0 and 1.15.1
Collecting tensorflow
  Downloading tensorflow 1.15.5, 1.15.4, 1.15.3, 1.15.2 and 1.15.0 and tensorflow_estimator-1.15.1
Collecting tensorflow-text
  Downloading tensorflow_text 1.15.0 and 0.1.0
Collecting tensorflow
  Downloading tensorflow-1.14.0 and tensorflow_estimator-1.14.0
INFO: pip is looking at multiple versions of googleapis-common-protos to determine which version is compatible with other requirements. This could take a while.
Collecting googleapis-common-protos<2,>=1.52.0
  Downloading googleapis_common_protos 1.56.1, 1.56.0, 1.55.0, 1.54.0, 1.53.0 and 1.52.0
INFO: pip is looking at multiple versions of tensorflow-metadata to determine which version is compatible with other requirements. This could take a while.
Collecting tensorflow-metadata
  Downloading tensorflow_metadata 1.8.0, 1.7.0, 1.6.0 and 1.5.0
Collecting absl-py
  Downloading absl_py 0.12.0, 0.11.0, 0.10.0 and absl-py-0.9.0
INFO: pip is looking at multiple versions of absl-py to determine which version is compatible with other requirements. This could take a while.
Collecting tensorflow-metadata
  Downloading tensorflow_metadata 1.4.0, 1.2.0, 1.1.0, 1.0.0, 0.30.0, 0.29.0 and 0.28.0
ERROR: Operation cancelled by user

Error when saving the checkpoint

Hi! I ran into an issue when saving a checkpoint. I also described it in a comment under #446.

The issue occurred when training ended after 100 steps and the checkpoint was saved to the absolute path '/gpfsnyu/scratch/kf2395/jukemir_t5/pretrain/'.

My TensorStore version is 0.1.19.

As this solution mentioned, a relative path may cause problems, so I changed it to an absolute path.

When I used a relative path at the beginning, I got an error similar to the one in this comment. The error message:

ValueError: Error opening "zarr" driver: Error reading local file "./pretrain_model/checkpoint_5000.tmp-1650694933/state.param_states.decoder.decoder_norm.scale.v/.zarray": Invalid key: "./pretrain_model/checkpoint_5000.tmp-1650694933/state.param_states.decoder.decoder_norm.scale.v/.zarray" In call to configurable 'train' (<function train at 0x7fa6818348c0>))

Then I changed the path to an absolute one and the issue above was solved, but a new issue occurred.

ValueError: Error opening "zarr" driver: Error writing local file "/gpfsnyu/scratch/kf2395/jukemir_t5/pretrain/checkpoint_100.tmp-1650949633/state.param_states.decoder.layers_0.pre_cross_attention_layer_norm.scale.v/.zarray": Failed to acquire lock on file: /gpfsnyu/scratch/kf2395/jukemir_t5/pretrain/checkpoint_100.tmp-1650949633/state.param_states.decoder.layers_0.pre_cross_attention_layer_norm.scale.v/.zarray.__lock [OS error: Invalid argument] In call to configurable 'train' (<function train at 0x7f651e1e78c0>)

I tried deleting all the files in '/gpfsnyu/scratch/kf2395/jukemir_t5/pretrain/' and training again, but the issue persisted.

The detailed error message:

I0426 13:05:36.808195 140074531202880 train.py:516] Epoch 0 of 10000
I0426 13:05:36.808564 140055117031168 logging_writer.py:48] [0] collection=train timing/compilation_seconds=160.272345
I0426 13:05:36.828166 140074531202880 train.py:522] BEGIN Train loop.
I0426 13:05:36.828350 140074531202880 train.py:527] Training for 100 steps.
I0426 13:05:36.833504 140074531202880 trainer.py:517] Training: step 0
I0426 13:05:47.585027 140074531202880 trainer.py:517] Training: step 12
I0426 13:05:58.556400 140074531202880 trainer.py:517] Training: step 23
I0426 13:06:09.237899 140074531202880 trainer.py:517] Training: step 34
I0426 13:06:19.734536 140074531202880 trainer.py:517] Training: step 45
I0426 13:06:30.668152 140074531202880 trainer.py:517] Training: step 56
I0426 13:06:41.496444 140074531202880 trainer.py:517] Training: step 67
I0426 13:06:52.412244 140074531202880 trainer.py:517] Training: step 78
I0426 13:07:03.236425 140074531202880 trainer.py:517] Training: step 89
I0426 13:07:13.692245 140074531202880 train.py:550] END Train loop.
I0426 13:07:13.727353 140055117031168 logging_writer.py:48] [100] collection=train accuracy=0.12926435470581055, cross_ent_loss=3456.254063, cross_ent_loss_per_all_target_tokens=0.337525, learning_rate=0.001000000280328095, learning_rate/current=0.0010000000474974513, loss=3460.679688, loss_per_all_target_tokens=0.337957, loss_per_nonpadding_target_token=5.071336, nonpadding_fraction=0.066641, timing/seconds=96.861853, timing/seqs=1000, timing/seqs_per_second=10.323982, timing/seqs_per_second_per_core=10.323982, timing/steps_per_second=1.032398, timing/target_tokens_per_second=10571.757297, timing/target_tokens_per_second_per_core=10571.757297, z_loss=4.426097, z_loss_per_all_target_tokens=0.000432
I0426 13:07:13.728666 140074531202880 train.py:565] Saving checkpoint.
I0426 13:07:13.730171 140074531202880 checkpoints.py:631] Saving checkpoint for step 100 to /gpfsnyu/scratch/kf2395/jukemir_t5/pretrain/checkpoint_100.tmp-1650949633
Traceback (most recent call last):
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/train.py", line 663, in <module>
    gin_utils.run(main)
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/gin_utils.py", line 107, in run
    flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/train.py", line 641, in main
    _main(argv)
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/train.py", line 661, in _main
    train_using_gin()
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/train.py", line 568, in train
    checkpoint_cfg.save.state_transformation_fns)
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/checkpoints.py", line 639, in save
    tmp_dir, train_state, concurrent_gb, state_transformation_fns)
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/checkpoints.py", line 806, in _write_state_to_tensorstore
    written_state_dict = _run_future_tree(future_written_state)
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/checkpoints.py", line 167, in _run_future_tree
    leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
  File "/gpfsnyu/scratch/kf2395/.cache/env/tf2-gpu-py3.7/lib/python3.7/asyncio/base_events.py", line 587, in run_until_complete
    return future.result()
  File "/gpfsnyu/scratch/kf2395/jukemir_t5/t5x/checkpoints.py", line 770, in _write_array
    'limit': 128
ValueError: Error opening "zarr" driver: Error writing local file "/gpfsnyu/scratch/kf2395/jukemir_t5/pretrain/checkpoint_100.tmp-1650949633/state.param_states.decoder.layers_0.pre_cross_attention_layer_norm.scale.v/.zarray": Failed to acquire lock on file: /gpfsnyu/scratch/kf2395/jukemir_t5/pretrain/checkpoint_100.tmp-1650949633/state.param_states.decoder.layers_0.pre_cross_attention_layer_norm.scale.v/.zarray.__lock [OS error: Invalid argument]
In call to configurable 'train' (<function train at 0x7f651e1e78c0>)

Thank you for your kind help!

How to decide how many steps to train the model given a custom dataset

I have been pretraining a T5 1.1 Base model locally on GPUs. I have reduced the batch_size from 256 to 64 so that training fits on my GPUs. However, I need some advice on how to decide how many steps to train my model, given that I am training on a custom dataset and have reduced the batch_size.

Currently, after going through the documentation, I am using the default 1,000,000 steps to train my model.
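
As a rough back-of-the-envelope sketch (my own assumption that the goal is to match the number of examples seen by the default schedule, not official T5X guidance), the step count scales inversely with the batch size:

# Rough sketch: the default schedule is batch size 256 for 1,000,000 steps.
# Matching the number of examples seen at batch size 64 means 4x as many steps.
# Whether that target is appropriate for a custom dataset is a separate question.
default_batch_size = 256
default_train_steps = 1_000_000
my_batch_size = 64

examples_seen = default_batch_size * default_train_steps  # 256,000,000 examples
equivalent_steps = examples_seen // my_batch_size          # 4,000,000 steps at batch size 64
print(equivalent_steps)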

Out of RAM when training T5x on WMT14 En-De with Colab's TPU

Hi,

I wanted to check if I can train and evaluate T5x Base on Colab's TPU.

When I try it in this Colab notebook, the session crashes after requesting more RAM than available (12 GB).

The last information log I receive is

trainer.py:472] Training: step 0

And the last warning I get is:

/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py:800: UserWarning: Some donated buffers were not usable: s32[], f32[768]{0}, f32[768]{0}, f32[768]{0}, [...]

Do you know if this is a JAX / TPU related issue, and whether one can use Colab's TPU / GPU to train and evaluate T5x?

Thank you for your time,

Leo

SeqIO evaluation hangs on final batch when dataset is padded

Hi all,

During SeqIO inference evaluation, the data pipeline hangs for a long time on the final batch of data, but only when the dataset is padded, sometimes for hours depending on the dataset size. This relates to these lines of code.

This is not an issue when the dataset is not padded, i.e. when the dataset length is divisible by batch size.

I am guessing it is the pad_ds built on this line that is very slow to process, but I am not sure why:

      pad_ds = ds.take(1).map(lambda i, x: (np.int64(-1), x)).repeat(
          dataset_pad_amt)

Any idea how we could speed up the data pipeline when padding?
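
One untested idea (an assumption on my part, not something taken from the SeqIO code): if the hang comes from ds.take(1) re-running the upstream preprocessing pipeline for every repeated padding example, caching that single example before repeating it might help:

import numpy as np

# Hedged sketch: cache the single padding example so that .repeat() replays the
# cached element instead of re-executing the upstream pipeline dataset_pad_amt times.
pad_ds = (
    ds.take(1)
    .cache()
    .map(lambda i, x: (np.int64(-1), x))
    .repeat(dataset_pad_amt)
)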

Fatal Python error: Segmentation fault, when training t5x-XXL on a TPU Pod v3-32

Hi,

I was able to train and run inference with prompt tuning on t5x-XXL on a TPU Pod v3-32 for my custom task defined from a TSV file, but I am now seeing an error that I can't understand.

I followed the instructions from prompt tuning to train and infer a prompt on a Pod slice, except that the latest libtpu release gives the error TPUEmbeddingEngineState_Create not available in this library, so I installed the release from February 15, 2022.

I run the following script:


MODEL_DIR=${1:-${MODEL_DIR}}
TFDS_DATA_DIR=${2:-${TFDS_DATA_DIR}}

if [ -z "${MODEL_DIR}" ] || [ -z "${TFDS_DATA_DIR}" ]; then
  echo "usage: ./rec_sys.sh gs://your-bucket/path/to/model_dir gs://your-bucket/path/to/tfds/cache"
  exit 1
fi

T5X_DIR="`python3 -m prompt_tuning.scripts.find_module t5x`/.."
FLAXFORMER_DIR="`python3 -m prompt_tuning.scripts.find_module flaxformer`/.."
PROMPT_DIR="`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.."
echo "Searching for gin configs in:"
echo "- ${T5X_DIR}"
echo "- ${FLAXFORMER_DIR}"
echo "- ${PROMPT_DIR}"
echo "============================="
PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl/checkpoint_1100000"

python3 -m t5x.train \
  --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \
  --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" \
  --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --gin.BATCH_SIZE="16" \
  --gin.MIXTURE_OR_TASK_NAME="'yelp'" \
  --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" \
  --gin.USE_CACHED_TASKS="False" \
  --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \
  --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" \
  --gin.TRAIN_STEPS="1_100_010"
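
(As a sanity check on the partitioning flag, here is my own arithmetic rather than anything from the T5X docs: the submesh spans the whole v3-32 slice.)

# model_parallel_submesh (4, 4, 1, 2) covers 4 * 4 * 1 * 2 = 32 cores, i.e. the
# entire v3-32 slice is used for model parallelism, leaving one data-parallel replica.
submesh = (4, 4, 1, 2)
cores = 1
for d in submesh:
    cores *= d
print(cores)  # 32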

When I run it, I get the following errors:
First:


tensorstore/internal/oauth2/google_auth_provider.cc:163: Credentials file not found. NOT_FOUND: $GOOGLE_APPLICATION_CREDENTIALS is not set or corrupt.
tensorstore/internal/oauth2/google_auth_provider.cc:168: Credentials file not found. NOT_FOUND: Could not find the credentials file in the standard gcloud location [/home/leojlaugier/.config/gcloud/application_default_credentials.json]
tensorstore/internal/oauth2/google_auth_provider.cc:203: Running on GCE, using GCE Auth Provider
Fatal Python error: Segmentation fault

Thread 0x00007f3304539c40 (most recent call first):
  File "/usr/lib/python3.8/selectors.py", line 468 in select
  File "/usr/lib/python3.8/asyncio/base_events.py", line 1823 in _run_once
  File "/usr/lib/python3.8/asyncio/base_events.py", line 570 in run_forever
  File "/usr/lib/python3.8/asyncio/base_events.py", line 603 in run_until_complete
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 160 in _run_future_tree
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 913 in _read_state_from_tensorstore
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 860 in restore
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 455 in _restore_path
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 466 in from_checkpoints
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 507 in from_checkpoint
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 522 in from_checkpoint_or_scratch
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 320 in train
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/gin/config.py", line 1582 in gin_wrapper
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 623 in _main
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 605 in main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251 in _run_main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303 in run
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/gin_utils.py", line 105 in run
  File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 625 in <module>
  File "/usr/lib/python3.8/runpy.py", line 87 in _run_code
  File "/usr/lib/python3.8/runpy.py", line 194 in _run_module_as_main
https://symbolize.stripped_domain/r/?trace=7f330498e18b,7f330498e20f,6&map=
*** SIGSEGV (@0x7d100002a60), see gl__________41#s15 received by PID 10848 (TID 12157) on cpu 19; stack trace: ***
PC: @     0x7f330498e18b  (unknown)  raise
    @     0x7f32fb6ea1fa        992  (unknown)
    @     0x7f330498e210  (unknown)  (unknown)
    @                0x7  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f330498e18b,7f32fb6ea1f9,7f330498e20f,6&map=55976a7e1de583f3a9544af1c86ac940:7f32ed01c000-7f32fba50d80
E0310 16:51:25.580514   12157 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
E0310 16:51:25.580525   12157 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
E0310 16:51:25.580535   12157 client.cc:221] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0310 16:51:25.580557   12157 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
E0310 16:51:25.580562   12157 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0310 16:51:25.580565   12157 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0310 16:51:25.580569   12157 coredump_hook.cc:550] RAW: Discarding core.

And later:

I0310 16:54:01.224571 140510105828416 train.py:456] Epoch 1100 of 1101
I0310 16:54:01.224764 140510105828416 train.py:462] BEGIN Train loop.
I0310 16:54:01.224818 140510105828416 train.py:467] Training for 10 steps.
I0310 16:54:01.226046 140497868457728 logging_writer.py:48] [1100000] collection=train timing/compilation_seconds=87.301567
I0310 16:54:01.230673 140510105828416 trainer.py:491] Training: step 1100000
I0310 16:54:01.635557 140510105828416 train.py:490] END Train loop.
./train_yelp_xxl.sh: line 34: 12024 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --gin.BATCH_SIZE="16" --gin.MIXTURE_OR_TASK_NAME="'yelp'" --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.USE_CACHED_TASKS="False" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_100_010"
##### Command execution on worker 1 failed with return code 134. Continuing.
./train_yelp_xxl.sh: line 34: 10848 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --gin.BATCH_SIZE="16" --gin.MIXTURE_OR_TASK_NAME="'yelp'" --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.USE_CACHED_TASKS="False" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_100_010"
##### Command execution on worker 0 failed with return code 134. Continuing.
./train_yelp_xxl.sh: line 34: 11607 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --gin.BATCH_SIZE="16" --gin.MIXTURE_OR_TASK_NAME="'yelp'" --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.USE_CACHED_TASKS="False" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_100_010"
##### Command execution on worker 3 failed with return code 134. Continuing.

Then the run freezes. I might be missing something obvious, but I don't think I have changed anything except the data since the last time I was able to train and infer with prompt tuning. Moreover, I was able to train on the same training data; the problems only arose when I tried to run inference.
Therefore, I'm asking if you could help me understand the issue.

Thanks in advance for your time.
