SeqIO

Task-based datasets, preprocessing, and evaluation for sequence models

Go to SeqIO ReadTheDocs Documentation Page.

Overview

SeqIO is a library for processing sequential data to be fed into downstream sequence models. It uses tf.data.Dataset to create scalable data pipelines but requires minimal use of TensorFlow. In particular, with one line of code, the returned dataset can be transformed to a numpy iterator and hence it is fully compatible with other frameworks such as JAX or PyTorch.
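
For example, here is a minimal sketch (using the "wmt19_ende" task defined later in this README) of pulling examples out as numpy arrays for a non-TensorFlow training loop:

dataset = seqio.get_mixture_or_task("wmt19_ende").get_dataset(
    sequence_length={"inputs": 256, "targets": 128},
    split="train",
    shuffle=True)

# One line to leave TensorFlow: the numpy iterator can feed JAX or PyTorch code.
for ex in dataset.as_numpy_iterator():
  ...  # each `ex` is a dict of numpy arrays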

SeqIO assumes that the dataset is a sequence. Modalities such as text or audio are naturally supported. Images are supported as long as they are represented as sequences (e.g., Image GPT).

SeqIO is a refactor of the t5.data library used (in conjunction with the Mesh Tensorflow Transformer implementation) to train the T5 models introduced in Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.

If you have used t5.data in the past and want to know how SeqIO differs, please read this section.

Installation

From PyPI

pip install seqio

From Source

git clone https://github.com/google/seqio.git
cd seqio
pip install -e .

Usage Tutorial

At a high level, we use SeqIO with the following steps.

  1. Define a Task (and optionally a Mixture).

  2. Define a FeatureConverter (or use an existing one) based on the model architecture.

  3. Use the top-level function seqio.get_dataset to obtain the tf.data.Dataset instance.

We will look at each of these steps in detail.
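
Concretely, these steps boil down to something like the following sketch (the task name and feature lengths are illustrative; each piece is explained in the sections below):

dataset = seqio.get_dataset(
    mixture_or_task_name="wmt19_ende",
    task_feature_lengths={"inputs": 32, "targets": 32},
    dataset_split="train",
    shuffle=True,
    feature_converter=seqio.EncDecFeatureConverter(pack=True))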

Defining a Task

The most important class in SeqIO is the Task. It is an abstraction that combines:

  • a raw data source
  • one or more preprocessing steps
  • a vocabulary to tokenize/detokenize each preprocessed feature for the model
  • a postprocessor to convert detokenized model outputs into a format for evaluation
  • one or more metrics to evaluate with

Oftentimes a Task lines up with a common benchmark. In this tutorial, we use the WMT 19 English-German machine translation task. In the end, our Task will look like this:

seqio.TaskRegistry.add(
    "wmt19_ende",
    seqio.TfdsDataSource(tfds_name="wmt19_translate/de-en:1.0.0"),
    preprocessors=[
        functools.partial(
            translate, source_language='en', target_language='de'),
        seqio.preprocessors.tokenize, seqio.preprocessors.append_eos
    ],
    output_features={
        'inputs':
            seqio.Feature(
                seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),
                add_eos=False,
                dtype=tf.int32),
        'targets':
            seqio.Feature(
                seqio.SentencePieceVocabulary('/path/to/targets/vocab'),
                add_eos=True,
                dtype=tf.int32),
    },
    metric_fns=[bleu])

We typically add the Task to the global registry when we define it (as shown above) to make it easier to use with model configs and flags. Thus, it must have a unique string name ("wmt19_ende" in this case). Note, however, that you may also instantiate a seqio.Task directly without adding it to the registry, if desired.

We'll now break down each part of the task definition.

Data Source

Data sources are the first step in your pipeline, providing a way to load raw data in many formats as a tf.data.Dataset. All data sources are subclasses of the DataSource base class and are defined in dataset_providers.

Existing implementations include:

  • TfdsDataSource for loading examples from TensorFlow Datasets.
  • TextLineDataSource for loading examples from text files (e.g., tsv).
  • TFExampleDataSource for loading tf.train.Example protos from a file (e.g., a TFRecord file).
  • FunctionDataSource for providing a custom function that returns a tf.data.Dataset.

In our example, we are using the TfdsDataSource. We specify the name of the WMT dataset in TFDS ("wmt19_translate"), the specific config for the German-English language pair ("de-en"), and the version number ("1.0.0").
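
For comparison, here is a hedged sketch of what a TextLineDataSource over hypothetical TSV files might look like (the file patterns are illustrative and not part of the WMT example):

tsv_source = seqio.TextLineDataSource(
    split_to_filepattern={
        "train": "/path/to/train-*.tsv",        # hypothetical paths
        "validation": "/path/to/validation.tsv",
    },
    skip_header_lines=0)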

Output Features

The output_features field expects a dictionary that maps string feature names to seqio.Feature objects. This defines what the Task is expected to produce in its output examples. The output examples may contain additional fields, but they must contain these fields in the specified format or exceptions will be raised.

Each Feature includes:

  • A vocabulary, which must subclass seqio.Vocabulary, to specify how the feature can be tokenized and detokenized. You may use seqio.PassThroughVocabulary if tokenization is not necessary.
  • add_eos, which specifies whether the feature should end with the vocabulary's EOS token.
  • The output dtype which must be a tf.dtypes.DType.

Note: specifying these options on Feature does not by itself ensure the proper transformations are applied -- you must also include the necessary preprocessors.
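
As a hedged illustration (the feature names and vocabulary path below are made up), an output_features mapping can mix a tokenized text feature with an already-numeric feature that only needs a seqio.PassThroughVocabulary:

output_features = {
    'targets': seqio.Feature(
        seqio.SentencePieceVocabulary('/path/to/vocab'),
        add_eos=True,
        dtype=tf.int32),
    # Hypothetical numeric feature that requires no tokenization.
    'label': seqio.Feature(
        seqio.PassThroughVocabulary(size=2),
        add_eos=False,
        dtype=tf.int32),
}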

The tasks used in T5 all produce "inputs" and "targets" features to be consumed by the text-to-text model. For a decoder-only language model, only a single feature (e.g., "targets") would be necessary. Nevertheless, SeqIO is flexible enough to generate arbitrary output features that will be converted into model features by the FeatureConverter later in the pipeline.

Preprocessors

Preprocessors are functions that transform one tf.data.Dataset into a new tf.data.Dataset. Typically this involves executing a map over the given dataset. The preprocessors provided to the Task will be executed sequentially.

As an example, let's look at the translate preprocessor used (but not yet defined) in the "wmt19_ende" example above.

def translate(dataset: tf.data.Dataset,
              source_language: str,
              target_language: str) -> tf.data.Dataset:
  def _translate(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Convert a translation example to a text2text pair.

    For example, say the dataset returns examples of this format:
      {'de': 'Das ist gut.', 'en': 'That is good.'}
    If source_language = 'de', target_language = 'en', then the outputs will
    have the format:
      {'inputs': 'translate de to en: Das ist gut.',
      'targets': 'That is good.'}

    Args:
      ex: an example to process.
      source_language: source language code (e.g. 'en') to translate from.
      target_language: target language code (e.g. 'de') to translate to.

    Returns:
      A preprocessed example with the format listed above.
    """
    src_str = f'translate {source_language}'
    tgt_str = f' to {target_language}: '
    return {
        'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),
        'targets': ex[target_language],
    }

  return dataset.map(_translate,
                     num_parallel_calls=tf.data.experimental.AUTOTUNE)

The TFDS data source provides examples of the form {'de': 'Das ist gut.', 'en': 'That is good.'}. We convert these to "inputs" and "targets" with an appropriate prompt to inform the model of the task.

A few important notes:

  1. When instantiating a Task, the preprocessor functions can have the following arguments: dataset, output_features, and sequence_length. The first (positional) dataset argument is always required. If an argument named output_features is provided, the output feature mapping will be passed to the preprocessor. If sequence_length is provided, a mapping from feature name to its maximum final sequence length (provided by the caller) will be passed -- any sequences that are too long after preprocessing will be automatically truncated. If a preprocessor function does have other arguments, they must have default values or be bound (e.g., with functools.partial as used in translate) before instantiating the Task.

  2. Mapping functions operate on and return tf.Tensors using TensorFlow operations. This is more flexible than it may sound:

    • Automatic AutoGraph conversion allows you to write Python control flow in your transformations.
    • tf.experimental.numpy provides a numpy interface.
    • tf.py_function allows you to wrap arbitrary Python code. Note: tf.data pipelines using this function can only be run in the python process where they were defined, and performance is limited by the python GIL.

    See tf.data.Dataset documentation for more details.

  3. When calling map, it is important to always set num_parallel_calls=tf.data.experimental.AUTOTUNE to avoid creating a bottleneck. The seqio.map_over_dataset decorator helps enforce this as follows.

    @seqio.map_over_dataset
    def translate(ex: Mapping[str, tf.Tensor],
                  source_language: str,
                  target_language: str) -> Mapping[str, tf.Tensor]:
      """Convert a translation dataset to a text2text pair.
    
      For example, say the dataset returns examples of this format:
        {'de': 'Das ist gut.', 'en': 'That is good.'}
      If source_language = 'de', target_language = 'en', then the outputs will
      have the format:
        {'inputs': 'translate de to en: Das ist gut.',
        'targets': 'That is good.'}
    
      Args:
        ex: an example to process.
        source_language: source language code (e.g. 'en') to translate from.
        target_language: target language code (e.g. 'de') to translate to.
    
      Returns:
        A preprocessed example with the format listed above.
      """
      src_str = f'translate {source_language}'
      tgt_str = f' to {target_language}: '
      return {
          'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),
          'targets': ex[target_language],
      }

    Note that translate takes an individual example as input. seqio.map_over_dataset then decorates it into a function that takes a tf.data.Dataset instance.

  4. Stochastic operations must be stateless if deterministic pipelines are needed. To get (optionally deterministic) seeds for these operations, use the seqio.map_over_dataset(num_seeds=n) decorator. For example:

    def random_chunk(
      dataset: tf.data.Dataset,
      sequence_length: Mapping[str, int]
    ) -> tf.data.Dataset:
      """Takes a random chunk out of each feature with size `sequence_length`."""

      @seqio.map_over_dataset(num_seeds=1)
      def take_chunk(
          ex: Mapping[str, tf.Tensor],
          seed
      ) -> Mapping[str, tf.Tensor]:
        new_ex = {}
        for k, v in ex.items():
          if k in sequence_length:
            length = sequence_length[k]
            start_idx = tf.random.stateless_uniform(
               (), seed, 0, tf.size(v) - (length + 1), dtype=tf.int32)
            new_ex[k] = v[start_idx:start_idx+length]
          else:
            new_ex[k] = v
        return new_ex

      return take_chunk(dataset)

    If num_seeds > 1, the arg will instead be called seeds and will contain a sequence of seeds.

In our "wmt_19_ende" task, we also use the predefined preprocessors seqio.preprocessors.tokenize and seqio.preprocessors.append_eos. The former uses each Feature.vocabulary to tokenize it, and the the latter appends Feature.vocabulary.eos_id to the feature if the Feature.add_eos is True. See preprocessors.py for their implementations and other useful preprocessors.

Postprocessor

During evaluation, the model outputs are first detokenized using the output feature vocabulary. Before passing these predictions to the metric functions, they can be run through a Python postprocessing function, alongside the full input example. Similarly, the raw targets are run through this function before being passed to the metrics. Since the postprocess function is used on both the model output and the targets, it is passed an is_target boolean in case the behavior should be different. It is also passed the fully preprocessed example, including fields that were excluded from output_features.

For the "wmt19_ende", we don't need any postprocessors. See "trivia_qa_open" task in the Advanced Postprocessing Task for an example postprocessor.

Metrics

Metrics are functions that are passed (by the Evaluator) the fully-materialized list of postprocessed model outputs (or scores) and targets and return a mapping from string names to MetricValue objects containing their values. These are most commonly floating-point scalars, but may also be text, images, audio, histograms, etc. (see metrics.py for the full list).

The first argument of a metric function must always be called targets. If the second argument of a metric function is called predictions, it will be passed the decoded and detokenized model prediction. If it is called scores, it will be passed a list of log-likelihood scores for each example.

If multiple metric functions are provided, they will all be used and their returned mappings merged.

Prediction Metrics

Prediction metrics are computed using the postprocessed targets and model outputs (predictions). The args must be named targets and predictions.

Let's look at the metric function used for the "wmt19_ende" task. A standard metric for translation is BLEU, and we use the sacrebleu implementation.

def bleu(targets: Sequence[str], predictions: Sequence[str]):
  """Computes BLEU score.

  Args:
    targets: list of strings or list of list of strings if multiple references
      are present.
    predictions: list of strings

  Returns:
    bleu_score across all targets and predictions
  """
  if isinstance(targets[0], list):
    targets = [[x for x in target] for target in targets]
  else:
    # Need to wrap targets in another list for corpus_bleu.
    targets = [targets]

  bleu_score = sacrebleu.corpus_bleu(predictions, targets,
                                     smooth_method="exp",
                                     smooth_value=0.0,
                                     force=False,
                                     lowercase=False,
                                     tokenize="intl",
                                     use_effective_order=False)
  return {"bleu": bleu_score.score}
Score Metrics

Score metrics are computed using the postprocessed targets and their log-likelihood scores according to the model. The args must be named targets and scores.

def perplexity(targets: Sequence[str], scores: Sequence[float]):
  return {
    "perplexity": seqio.metrics.Scalar(np.exp(np.mean(scores)))
  }

Defining a Mixture

Once you have multiple Tasks added to the TaskRegistry, you can define Mixtures that will combine the examples from them according to some specified rate. Examples will then be sampled from each task in proportion to its rate.

As an example, Multilingual T5 uses a Mixture of per-language Tasks with tail languages up-weighted in the mixture.

There are 3 ways to specify the tasks and their rates:

  1. Provide a rate along with each task's name (rates are normalized before sampling). In this example, the rates provided are units of the final mixture that come from the component tasks. Here, 1/(1+7) of the final mixture will come from "task1".

    seqio.MixtureRegistry.add(
      "mix1",
      [("task1", 1), ("task2", 7)]
    )
  2. Provide a constant default rate for some or all tasks, which will be used when only the name is provided. The example below will produce mixing rates identical to the previous one.

    seqio.MixtureRegistry.add(
      "mix1",
      [("task1", 0.5), "task2"],
      default_rate=3.5
    )
  3. Provide a function that generates the rate for each task at runtime. The example below uses the provided seqio.mixing_rate_num_examples, which uses the number of examples (computed during offline caching) as the rate for each task.

    seqio.MixtureRegistry.add(
      "mix2",
      ["task1", "task2"],
      default_rate=seqio.mixing_rate_num_examples
    )

You can also include Mixtures in your Mixture! For example, the following mixture would contain 1/24 + 1/3 of "task1" (via "mix1" and directly, respectively), 7/24 of "task2" (via "mix1"), and 1/3 of "task3".

seqio.MixtureRegistry.add(
  "mix3",
  ["mix1", "task1", "task3"],
  default_rate=1
)

If sampling without replacement is important for your task, you can achieve that either by using deterministic tasks or by using dataset checkpointing (and not running for more than an epoch) for a non-deterministic task. Otherwise, the mixture may sample with replacement.

Getting a Preprocessed Dataset

Now that your Task (and/or Mixture) is defined, its primary use is to generate a dataset.

You may first need to use seqio.get_mixture_or_task(mixture_or_task_name) to access your dataset provider from the registry.

After that, you can call get_dataset to build the tf.data.Dataset. For example:

dataset = seqio.get_mixture_or_task("mix1").get_dataset(
    sequence_length={"inputs": 256, "targets": 128},
    split="train",
    shuffle=True,
    num_epochs=1,
    shard_info=seqio.ShardInfo(index=0, num_shards=10),
    use_cached=False,
    seed=42
)

# Print the first 5 examples.
for _, ex in zip(range(5), dataset.as_numpy_iterator()):
  print(ex)

Some notes on a few of the arguments:

  • sequence_length: An optional mapping from feature name to maximum length. Will be passed to the preprocessors with a sequence_length argument. If not None, the final example features will be truncated if they exceed the specified length. Note that this value may be required to be set if any of the preprocessors use the sequence_length argument and do not handle the None case.
  • num_epochs: The number of times to repeat the source dataset. Preprocessing will be re-applied with new seeds to enable new samples from stochastic steps. Note that if the CacheDatasetPlaceholder is included (see below) preprocessing is only re-applied after that step.
  • shard_info: An optional sharding specification for loading a deterministic subset of the dataset. Loading will be most efficient if the number of shards evenly divides the number of shards in the raw data source.
  • use_cached: Specifies whether to load from a pre-cached task for increased performance or to do the preprocessing on-the-fly. See the following section for details on how to cache your task, which must be done before this can be set to True.
  • seed: An optional seed to use for deterministic shuffling and (stateless) stochastic ops. These operations will still be pseudorandom but will be reproducible with the same seed. Set to None if determinism is not desired.

(Optional) Offline Caching

For improved performance at load time and to avoid redundant computations for commonly used tasks, you can pre-cache your Task with all or part of the preprocessing done in advance of training; this partial preprocessing is especially useful if the Task is stochastic and one wishes to cache the deterministic operations while running the stochastic ones on the fly. Caching stochastic SeqIO Mixtures in this way is not supported.

The first step to doing so is to add a seqio.CacheDatasetPlaceholder(required=False) as one of the steps in your preprocessing pipeline. All steps before the placeholder will be cached offline and all steps after will be executed on the fly at load time. You may set required=True if you want get_dataset to fail unless use_cached=True.
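
For example, here is a sketch of the "wmt19_ende" preprocessor list with a cache placeholder added; the placement shown (after tokenization, before the EOS step) is just one reasonable choice:

preprocessors=[
    functools.partial(
        translate, source_language='en', target_language='de'),
    seqio.preprocessors.tokenize,
    seqio.CacheDatasetPlaceholder(required=False),
    seqio.preprocessors.append_eos,
],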

Caveats:

  • Any stochastic operations that you wish to be re-run when num_epochs > 1 or with a different seed should go after the placeholder since only a single sample will be cached.
  • Any preprocessing steps that use the sequence_length argument must come after the seqio.CacheDatasetPlaceholder preprocessor since this is only known at runtime, or an exception will be raised. If you wish to cache for a specific sequence length, you can use seqio.experimental.add_fully_cached_task.

Once your Task is registered, you can run cache_tasks_main to execute the offline preprocessing, providing it with the module containing your task definitions via the --module_import flag. For very large datasets, it's recommended you run this Apache Beam script on a distributed framework like Google Cloud DataFlow.

Finally, you are ready to load the cached version of your Task (or a Mixture containing it). You will need to add the path to the directory you passed to --output_cache_dir via seqio.add_global_cache_dirs(["/my/cache/dir"]). Now when you call task_or_mixture.get_dataset(..., use_cached=True), the data will be loaded from the cache directory instead of the raw data source.
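
Putting these together, loading from the cache might look like the following sketch (paths are illustrative):

seqio.add_global_cache_dirs(["/my/cache/dir"])

dataset = seqio.get_mixture_or_task("wmt19_ende").get_dataset(
    sequence_length={"inputs": 256, "targets": 128},
    split="train",
    use_cached=True)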

Feature Converters

The role of the Task is to provide a dataset object with model-agnostic features (e.g., generic "inputs" and "targets"), while Feature Converters transform these model-agnostic features into model-specific features (e.g., "encoder_input_tokens"). We refer to the former as "task features" and the latter as "model features".

Let's use machine translation (English to German) as a running example.

The raw data consists of sentence pairs such as

"That is good\tDas ist gut."

A registered Task (e.g., wmt_t2t_ende_v003) reads these sentence pairs from the data source and applies a series of preprocessors. One of the intermediate representations looks like

{"inputs": "translate English to German: That is good.",
 "targets": "Das ist gut."}

The final output from the Task is a tokenized version of the parallel sentences. In the following toy example (the token ids do not correspond to the above string example), the dataset consists of 2 examples.

dataset = [{"inputs": [7, 8, 5], "targets": [3, 9]},
           {"inputs": [8, 4, 9, 3], "targets": [4]}]

The data is represented as a tf.data.Dataset (i.e., each example is a dictionary with "inputs" and "targets" fields).

The FeatureConverter then takes this as input and converts it to the model-specific features. In addition, the feature converter performs padding and, optionally, packing (for model implementations that support it) for efficiency. For example, let's assume that we are using the standard Transformer architecture with an encoder and a decoder. The output of the feature converter is

converted_dataset = [{
    "encoder_input_tokens": [7, 8, 5, 1, 8, 4, 9, 3, 1, 0],
     "encoder_segment_ids": [1, 1, 1, 1, 2, 2, 2, 2, 2, 0],
       "encoder_positions": [0, 1, 2, 3, 0, 1, 2, 3, 4, 0],
   "decoder_target_tokens": [3, 9, 1, 4, 1, 0, 0],
    "decoder_input_tokens": [0, 3, 9, 0, 4, 0, 0],
    "decoder_loss_weights": [1, 1, 1, 1, 1, 0, 0],
       "decoder_positions": [0, 1, 2, 0, 1, 0, 0],
     "decoder_segment_ids": [1, 1, 1, 2, 2, 0, 0],
}]

In this case, two task examples are packed into one. *_segment_ids and *_positions are the fields used to denote the membership and position of each packed token in the original sequence. The EOS ids (i.e., 1) are appended. In addition, each field is padded to the specified length.

We will look at the details of this example in Encoder-decoder architecture: seqio.EncDecFeatureConverter section.

Feature converters provided out of the box

We provide feature converters for three common architectures: encoder-decoder, decoder-only and encoder-only. Here we describe how users can use the feature converters for each of these architectures out of the box as a part of the SeqIO library.

In the SeqIO library, each architecture has a class defining how the task features are converted to model features. Since these feature converters are already implemented, it is straightforward to use them by providing the class as a feature_converter argument of the seqio.get_dataset function. The following sections show example usage of seqio.get_dataset.

Encoder-decoder architecture: seqio.EncDecFeatureConverter

This is the architecture of the original Transformer paper. For the English-to-German translation task, the following function call retrieves the tf.data.Dataset object with the model features.

dataset: tf.data.Dataset = seqio.get_dataset(
    mixture_or_task_name="wmt_t2t_ende_v003",
    task_feature_lengths={"inputs": 32, "targets": 32},
    dataset_split="train",
    shuffle=True,
    feature_converter=seqio.EncDecFeatureConverter(pack=True)
)

The resulting dataset object has the following 8 fields:

Feature name             Explanation
encoder_input_tokens     Input tokens to the encoder.
encoder_positions        Position index in the sequence before packing.
encoder_segment_ids      Sequence membership before packing. Two positions with the same positive integer mean that they belong to the same sequence before packing.
decoder_input_tokens     Input tokens to the decoder.
decoder_target_tokens    Output tokens from the decoder.
decoder_loss_weights     A weight on each position that can be used as a mask.
decoder_positions        Position index in the sequence before packing.
decoder_segment_ids      Same as encoder_segment_ids but for the decoder.

Decoder-only architecture

This architecture consists of a single autoregressive stack, which we denote as a "decoder".

A decoder autoregressively produces an output sequence. Therefore, it can be used as a standard language model if the task dataset has only a "targets" feature, i.e., self-supervised. If the task dataset also has an "inputs" field, e.g., supervised machine translation, the decoder can still be used by concatenating the inputs and targets fields. See Raffel et al. (2020), Section 3.2.1 for a more detailed take on this topic.

We support both use cases and refer to the former as the standard language model and the latter as the prefix language model. Each of these models is described separately below.

Note that we do not provide special features to denote how the dataset should be consumed. For example, a Transformer-based fully autoregressive decoder has a fully-causal self-attention layer. Since there are many ways of implementing the masking pattern for such attention layer and, more importantly, SeqIO is not limited to attention-based models, we leave it up to the model implementations to apply the masking pattern. There is one exception, and we cover this in the Prefix LM section below.

A common use pattern is to pretrain a decoder model with the left-to-right language modeling objective (unsupervised) using seqio.LMFeatureConverter and then fine-tune (supervised) using seqio.PrefixLMFeatureConverter.

Standard LM

For the standard language model, the task dataset only has a "targets" field. Therefore, the sequence length specification only needs to specify "targets".

dataset: tf.data.Dataset = seqio.get_dataset(
    mixture_or_task_name="standard_lm",
    task_feature_lengths={"targets": 32},
    dataset_split="train",
    shuffle=True,
    feature_converter=seqio.LMFeatureConverter(pack=True)
)

Note that "standard_lm" is not a registered task in the codebase. It is the left-to-right language modeling task, i.e., predict the next token given the previous tokens on some language corpus (e.g., C4).

The output dataset has the following model features.

Feature name             Explanation
decoder_target_tokens    Output tokens from the decoder.
decoder_input_tokens     Input tokens to the decoder.
decoder_loss_weights     Binary mask to indicate where the loss should be taken.
decoder_positions        Position index in the sequence before packing.
decoder_segment_ids      Sequence membership before packing. Two positions with the same positive integer mean that they belong to the same sequence before packing.

The decoder_input_tokens is a shifted version of decoder_target_tokens, as used in standard teacher-forced autoregressive training.
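
For intuition, here is a rough sketch (token ids are made up) of the unpacked LM features for a single example whose tokenized "targets" already end with the EOS id 1:

task_example = {"targets": [3, 9, 4, 1]}

# With task_feature_lengths={"targets": 6} and pack=False, the model features
# would look roughly like:
converted_example = {
    "decoder_target_tokens": [3, 9, 4, 1, 0, 0],
    "decoder_input_tokens":  [0, 3, 9, 4, 1, 0],  # targets shifted right, BOS = 0
    "decoder_loss_weights":  [1, 1, 1, 1, 0, 0],  # no loss on padding
}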

Prefix LM: seqio.PrefixLMFeatureConverter

If the input dataset has a notion of "inputs" and "targets", we can concatenate them so that we can still use a single decoder stack. Therefore, the output only contains "targets", just like in the standard LM case.

We use the same toy example for English-to-German translation task as a running example:

{"inputs": "translate English to German: That is good.",
 "targets": "Das ist gut."}

To be consumed by the decoder-only stack, seqio.PrefixLMFeatureConverter concatenates them to form the new "targets". Consider a 2-layer decoder architecture whose activations are shown below:


That  is  good <EOS> Das ist gut <EOS>
 |    |    |    |    |   |    |   |
 u1   u2   u3   u4   u5  u6   u7  u8
 |    |    |    |    |   |    |   |
 v1   v2   v3   v4   v5  v6   v7  v8
 |    |    |    |    |   |    |   |
<BOS> That is  good <EOS> Das ist gut

Let us denote the first layer's activation in the ith position as vi. Similarly, let ui denote the activation of the second layer in the ith position.

For attention-based sequence models such as Transformer decoders, the self-attention layer is used to encode contextualized representation of the sequence. At a given layer, each position's representation is computed as a function of the representations of the tokens before its position in the previous layer.

Referring to the toy example, when computing u2 with fully-causal masking, we do not use v3. This results in a representation u2 of the word "is" that does not take into account the word "good", which is unnecessarily limiting.

For Prefix LM, this issue is resolved by having the fully visible masking pattern for the inputs portion only. For example, when computing u2, v1, v2, v3, v4 and v5 are all visible and taken into account. For the tokens in the "targets" of the Task dataset, we use the causal masking. For example, when computing u6, all vi for i <= 6 are taken into account but not v7.

Why is `v5` included in the inputs attention pattern? In the same translation example, we note that when computing `u2`, the activation corresponding to the position where the <EOS> token was input (i.e., `v5`) was visible. This doesn't count as "cheating" because the model doesn't see the next word "Das". This can provide additional context in building the representation for "good": in this case, `u4` has the context that "good" is the last word in the sentence.

seqio.PrefixLMFeatureConverter provides a feature decoder_causal_attention to encode this information. For the above example, we have

decoder_causal_attention = [1, 1, 1, 1, 1, 0, 0, 0]

indicating that non-causal attention can be applied to the first five positions. Note that this feature seems trivial, but for a packed dataset the boundary between inputs and targets is more nuanced.

A final consideration for the prefix LM is that, because we concatenate "inputs" and "targets", which tokens are used for the loss computation is a modeling decision. For example, we can penalize the model only for the "targets" tokens, or we may choose to also penalize it for building the representations of the "inputs" tokens. This is controlled by the loss_on_targets_only argument (defaults to True) to the seqio.PrefixLMFeatureConverter constructor. In the above example, we would get

decoder_loss_weights = [0, 0, 0, 0, 1, 1, 1, 1]

This indicates that the last 4 positions are used for the loss computation.

To get the dataset with prefix LM features, we can use

dataset: tf.data.Dataset = seqio.get_dataset(
    mixture_or_task_name="wmt_t2t_ende_v003",
    task_feature_lengths={"inputs": 32, "targets": 32},
    dataset_split="train",
    shuffle=True,
    feature_converter=seqio.PrefixLMFeatureConverter(
        pack=True,
        loss_on_targets_only=True)
)

The resulting features have length 64 because the inputs and targets, each of length 32, are concatenated.

The output dataset has the following model features. Note that the only additional feature is decoder_causal_attention.

Feature name              Explanation
decoder_target_tokens     Output tokens from the decoder.
decoder_input_tokens      Input tokens to the decoder.
decoder_loss_weights      Binary mask to indicate where the loss should be taken.
decoder_positions         Position index in the sequence before packing.
decoder_segment_ids       Sequence membership before packing. Two positions with the same positive integer mean that they belong to the same sequence before packing.
decoder_causal_attention  Binary mask denoting which tokens are in the non-causal masking region.

Encoder-only architecture

Like the decoder-only architecture, this one is a single stack, but it is not autoregressive.

One notable assumption is that the inputs and targets are aligned, i.e., they have the same sequence length and the ith position in the targets corresponds to the output representation of the ith token in the inputs.

A common model using the encoder-only architecture is BERT. We provide the EncoderFeatureConverter class to support the Masked Language Modeling (MLM) objective from BERT.

We assume that a unique sentinel such as a [MASK] token is used to mask some fraction of the input text and that the task is to recover the original text. Therefore, the "targets" feature is naturally defined as the original text whereas "inputs" is the masked text.

Encoder-only models are often used for classification tasks. In BERT, a special token [CLS] is prepended to the input sequence. The last layer's activation corresponding to this sentinel token is the contextualized representation of the sequence. We assume that such a "classification" sentinel is prepended.

Consider the following example for the MLM task. The input dataset has two examples, which are packed into one example. We assume that mask_id = 9 and that the [CLS] token has an id of 8.

dataset = [{"inputs": [8, 9, 9, 3, 4], "targets": [8, 7, 4, 3, 4]},
           {"inputs": [8, 3, 9], "targets": [8, 3, 6]}]

converted_dataset = {
     "encoder_input_tokens": [8, 9, 9, 3, 4, 1, 8, 3, 9, 1, 0],
    "encoder_target_tokens": [8, 7, 4, 3, 4, 1, 8, 3, 6, 1, 0],
      "encoder_segment_ids": [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0],
        "encoder_positions": [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0],
     "encoder_loss_weights": [0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
}

Note that the packed sequence has a [CLS] token at the beginning of each sequence. Also note that the loss is computed only on the masked positions.

To use the pre-defined EncoderFeatureConverter, provide mask_id as an argument.

dataset: tf.data.Dataset = seqio.get_dataset(
    mixture_or_task_name="some mlm task",
    task_feature_lengths={"inputs": 32, "targets": 32},
    dataset_split="train",
    shuffle=True,
    feature_converter=seqio.EncoderFeatureConverter(
        pack=True,
        mask_id=9)
)

The resulting dataset object has the following 5 fields

Feature name           Explanation
encoder_input_tokens   Input tokens to the encoder.
encoder_positions      Position index in the sequence before packing.
encoder_segment_ids    Sequence membership before packing. Two positions with the same positive integer mean that they belong to the same sequence before packing.
encoder_target_tokens  Output tokens from the encoder.
encoder_loss_weights   Binary mask to indicate where the loss should be taken.

Custom architectures

For a custom model architecture, you need to create a subclass of FeatureConverter and override two methods, _convert_features and get_model_feature_lengths, to define how task features are mapped to the model features, including the length relationships. The existing feature converters (e.g., seqio.EncDecFeatureConverter) follow the same pattern, which can be a useful starting point.
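
A rough skeleton of such a subclass is sketched below. It assumes a hypothetical model that only needs the task's "targets" renamed to "decoder_target_tokens" (no packing), and it borrows the FeatureSpec and _pack_or_pad helpers used by the built-in converters; exact attribute and helper names may differ between SeqIO versions.

class SimpleDecoderFeatureConverter(seqio.FeatureConverter):
  """Sketch: renames "targets" to "decoder_target_tokens" (assumes pack=False)."""

  TASK_FEATURES = {
      "targets": seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
  }
  MODEL_FEATURES = {
      "decoder_target_tokens": seqio.FeatureConverter.FeatureSpec(dtype=tf.int32),
  }
  PACKING_FEATURE_DTYPES = {}

  def _convert_features(self, ds, task_feature_lengths):
    # Pad the task features to the requested lengths, then rename.
    ds = self._pack_or_pad(ds, task_feature_lengths)

    @seqio.map_over_dataset
    def _to_model_features(ex):
      return {"decoder_target_tokens": ex["targets"]}

    return _to_model_features(ds)

  def get_model_feature_lengths(self, task_feature_lengths):
    return {"decoder_target_tokens": task_feature_lengths["targets"]}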

Evaluation

The SeqIO Evaluator class provides a way to evaluate models on SeqIO Tasks and Mixtures. For an interactive walkthrough of SeqIO evaluation, see the Evaluation Notebook. The following is a deep-dive into the Evaluator class.

An Evaluator instance can be created by passing a SeqIO Task or Mixture, and additional eval params like feature converter, split, sequence lengths, seed, etc. The Evaluator init calls get_dataset for each Task to be evaluated with the appropriate params, creating the task_dataset, and invokes the model-specific feature converter on the task_dataset to create features that can be passed to a model, called model_dataset. Both task_dataset and model_dataset are stored in-memory so that the dataset can be reused across multiple evaluations (e.g. on checkpoints from a training run). Both datasets are enumerated so that even if the order of examples is changed during model inference, the enumeration can be used to match model outputs to examples from the task_dataset.

For Mixtures, each sub-Task is evaluated separately, regardless of mixing rates, because in the context of eval benchmarks, Mixtures commonly refer to a collection of Tasks belonging to that benchmark, each of which is evaluated separately, e.g. SuperGLUE mixture.

Once an Evaluator instance is created with a SeqIO Task or Mixture, a model can be evaluated by calling evaluator.evaluate(...) and passing a predict_fn and/or a predict_with_aux_fn and/or a score_fn to interact with the model. predict_fn takes the model_dataset as input and outputs a Sequence[(index, token_ids)] where token_ids is the sequence of token ids generated by the model for the input example whose index matches index. Therefore, even if predict_fn mixes the order of the examples during prediction, the order can be corrected as long as the correct index for each example is maintained. A common example is the multi-host setup where the evaluation dataset is split amongst multiple hosts that independently make predictions and combine the results during which the ordering can be mixed. predict_with_aux_fn is similar to predict_fn, except that it can also return a dictionary of auxiliary values along with each sequence of token_ids, e.g. scores from the generated tokens. The score_fn takes the model_dataset as input and returns a Sequence[(index, score)] where score is the sequence of log likelihood scores for the targets in the dataset. This simple interface allows users to easily integrate the SeqIO evaluation flow with popular training frameworks in TF and Jax.
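
As a hedged sketch of the flow described above (the model-specific predict_fn and score_fn are hypothetical, and constructor defaults and return values may vary across SeqIO versions):

evaluator = seqio.Evaluator(
    mixture_or_task_name="wmt19_ende",
    feature_converter=seqio.EncDecFeatureConverter(pack=False),
    eval_split="validation",
    sequence_length={"inputs": 256, "targets": 128})

# predict_fn consumes the converted model_dataset and returns
# Sequence[(index, token_ids)]; score_fn returns Sequence[(index, score)].
results = evaluator.evaluate(
    compute_metrics=True,
    step=1000,
    predict_fn=my_predict_fn,   # hypothetical, model-specific
    score_fn=my_score_fn)       # hypothetical, model-specific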

Corresponding to the model fns, users can configure three kinds of metric fns in their Tasks, which are differentiated by their function signature. Metrics computed on the outputs of predict_fn (and predict_with_aux_fn) have the signature targets and predictions (and optionally aux_values), while metrics computed on the outputs of score_fn have the signature targets and scores. The Evaluator takes care of calling the correct model fns and metric fns during evaluation. Here is an example of a metric of each type.

def sequence_accuracy(targets, predictions):
 seq_acc = 100 * np.mean([p == t for p, t in zip(predictions, targets)])
 return {"sequence_accuracy": seq_acc}

def log_likelihood(targets, scores):
 log_likelihood = np.mean([scipy.special.logsumexp(el) for el in scores])
 return {"log_likelihood": log_likelihood}

There are 4 steps involved in the evaluation using predicted tokens:

  • the predict_fn or predict_with_aux_fn returns indices and output_tokens: Sequence[Tuple[int, Sequence[int]]], potentially with some auxiliary values.
  • output tokens are decoded by vocab.decode
  • postprocessors configured in Tasks are applied to the decoded output. These are denoted as predictions.
  • metric fns configured in Tasks are applied to the predictions and the cached targets.

There are 2 steps involved in the evaluation using scores:

  • the score_fn returns indices and scores: Sequence[Tuple[int, Sequence[float]]]
  • metric fns configured in Tasks are applied to the scores and the cached targets.

Training codebases like T5X provide integration with SeqIO evaluation to allow evaluating checkpoints on SeqIO Tasks and Mixtures. See T5X Eval for instructions.

Differences from t5.data

The original t5 library introduced and implemented the t5.data.Task abstraction for specifying preprocessing and evaluation metrics for text-to-text tasks. When creating a task, users specify a source dataset of raw text, some preprocessing steps, a vocabulary for tokenization, and evaluation metrics. The fully-specified Task can then be used to pre-train or fine-tune an encoder-decoder transformer model. However, the design included many baked-in assumptions about the types of tasks users could specify.

SeqIO removes some of the constraints of this abstraction:

  • Inputs and outputs are no longer required to be strings (e.g., they may be images or audio).
  • Architectures other than the original encoder-decoder are supported (e.g., decoder-only language models like GPT or encoder-only models like BERT).
  • Users can control at which stage of the pipeline offline caching occurs.
  • Users can control when and where EOS tokens are added.

Furthermore, SeqIO has been made more modular with respect to the Mesh TensorFlow Transformer. This allows it to be used with other model implementations with more consistency and much less code duplication.

Advanced Postprocessing Task

TriviaQA (Closed-book, open-domain version)

This version of TriviaQA was introduced in Roberts et al. 2020.

seqio.TaskRegistry.add(
    "trivia_qa_open",
    source=seqio.TfdsDataSource(
      tfds_name="trivia_qa/unfiltered.nocontext:1.1.0",
      splits={
          "train": "train[:90%]",
          "validation": "train[90%:]",
          "test": "validation"
      }),
    preprocessors=[
        tqa_open_preprocessor,
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ],
    output_features={
        "inputs": seqio.Feature(
           seqio.SentencePieceVocabulary("/path/to/inputs/vocab"),
           add_eos=False, dtype=tf.int32
        ),
        "targets": seqio.Feature(
           seqio.SentencePieceVocabulary("/path/to/targets/vocab"),
           add_eos=True, dtype=tf.int32
        ),
    },
    postprocess_fn=tqa_open_postprocessor,
    metric_fns=[tqa_metric])

In this example, we are using the TfdsDataSource. We specify the name of the TriviaQA dataset in TFDS ("trivia_qa"), the specific config that excludes the context for the open domain setting ("unfiltered.nocontext"), and the version number ("1.1.0"). We also override the default splits to match what is commonly used for the open domain setting. Specifically, we set our "test" split to be the TFDS "validation" split, and create a small pseudo-"validation" set by taking examples out of the TFDS "train" split.

The preprocessor tqa_open_preprocessor is defined as follows.

def tqa_open_preprocessor(
    dataset: tf.data.Dataset,
    prefix:str = "trivia_qa question: "
  ) -> tf.data.Dataset:
  """Convert TriviaQA dataset to open domain qa examples.

  The function takes the trivia_qa TFDS dataset and emits examples of the
  form:
  {
    "inputs": "trivia_qa question: What are the names of the Olsen Twins?"
    "targets": "Mary-Kate and Ashley",
    "answers": ["Mary-Kate and Ashley", "Ashley and Mary-Kate"]
  }

  Args:
    dataset: a tf.data.Dataset to process.
    prefix: str, prefix to prepend to the inputs.

  Returns:
    a tf.data.Dataset
  """
  def tqa_map(ex):
    """Map TriviaQA example to text-to-text example."""
    return {
        "inputs": prefix + ex["question"],
        "targets": ex["answer"]["value"],
        "answers": ex["answer"]["aliases"],
    }

  return dataset.map(tqa_map, num_parallel_calls=tf.data.experimental.AUTOTUNE)

Or, with the seqio.map_over_dataset decorator, we have

def tqa_open_preprocessor(
  dataset: tf.data.Dataset,
  prefix: str = "trivia_qa question: "
) -> tf.data.Dataset:

  @seqio.map_over_dataset
  def tqa_map(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Map TriviaQA example to text-to-text example."""
    return {
        "inputs": prefix + ex["question"],
        "targets": ex["answer"]["value"],
        "answers": ex["answer"]["aliases"],
    }

  return tqa_map(dataset)

Here we made a thin wrapper to emphasize that the function decorated by seqio.map_over_dataset takes in an instance of tf.data.Dataset. In practice, this wrapper is not necessary.

The postprocessor for this example is tqa_open_postprocessor, which is defined as follows:

def tqa_open_postprocessor(output_or_target, example=None, is_target=False):
  """Returns output as answer, or all answers if the full example is provided."""
  if is_target:
    return [a.decode("utf-8") for a in example["answers"]]
  else:
    return output_or_target.decode("utf-8")

When processing the target, we ignore output_or_target (equivalent to example["targets"]) since it is just selecting a single answer in trivia_qa_open. Instead, we extract the full list of answers from the example and convert them from bytes to text. When handling the model output, we simply convert it to text from detokenized bytes.

The metric function tqa_metric is defined as:

def tqa_metric(
  targets: Sequence[Sequence[str]],
  predictions: Sequence[str]
) -> Mapping[str, seqio.metrics.MetricValue]:
  """Computes official TriviaQA metrics.

  Args:
    targets: list of lists of strings
    predictions: list of strings

  Returns:
    dict with the exact match score across all targets and predictions
  """

  if len(targets) != len(predictions):
    raise ValueError("Number of targets and predictions must match.")

  def _normalize_answer(text):
    """Lower text and remove punctuation, articles and extra whitespace."""
    # Lowercase.
    text = text.lower()
    # Remove articles.
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    # Remove punctuation.
    for punc in string.punctuation:
      text = text.replace(punc, '')
    # Normalize white space.
    text = " ".join(text.split())
    return text

  # Normalize answers before comparing.
  targets = [[_normalize_answer(t) for t in u] for u in targets]
  predictions = [_normalize_answer(p) for p in predictions]

  em = np.mean([
      max(pred == gt for gt in ground_truths)
      for pred, ground_truths in zip(predictions, targets)
  ])
  return {
      "exact_match": seqio.metrics.Scalar(em),
  }

Citing SeqIO

Please use the following bibtex entry to cite SeqIO.

@article{roberts2022t5x,
  url = {https://arxiv.org/abs/2203.17189},
  author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra,
  Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester,
  Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and
  Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and
  Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and
  Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and
  Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy,
  Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and
  Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and
  Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel,
  Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov,
  Alexander and Newlan, Joshua and Gesmundo, Andrea},
  title = {Scaling Up Models and Data with $\texttt{t5x}$ and $\texttt{seqio}$},
  journal={arXiv preprint arXiv:2203.17189},
  year = {2022},
}

seqio's Issues

Different preprocessors for each dataset split

Hi,

I'm working on an STS task using a seqio.TfdsDataSource and t5.data.preprocessors.stsb as the preprocessor. After seeing some generated examples I realized that both training and eval data are preprocessed, so metrics calculated on the test split use a processed version of the target instead of the original values. Using t5.data.preprocessors.stsb as an example, during training a label value of 3.25 is converted to "3.2"; during eval steps, I'd like to just convert this value to a float without rounding ("3.25").

Is there a way to apply different preprocessors for each dataset split? It would be ideal for the evaluation metric functions to consume gold labels as close as possible to the original values. The postprocessing section of the README mentions the is_target argument for postprocessing functions, but I couldn't find a similar mechanism for preprocessor functions.

Thanks,
Marcos

caching tasks goes out of memory due to apache beam

Trying to cache tasks from the magenta/MT3 repository: with only 200 examples it takes around 30 GB of memory at the very end of processing.
Without caching it trains just fine, even with a 1000-example train dataset.

seqio 0.0.13 cannot be installed on Apple Silicon due to transitive tensorflow dependency of clu

db4d4b0 added clu as a dependency of seqio.

With this change, we can no longer install seqio on Apple Silicon machines (e.g. M1, M2). This is because clu requires tensorflow (https://github.com/google/CommonLoopUtils/blob/85f9d28556f2684e2c5f2e412cbef5119d6682ba/setup.py#L54) but on Apple Silicon tensorflow should be installed as tensorflow-macos based on the instructions at https://developer.apple.com/metal/tensorflow-plugin/.

A simple fix is to update the clu tensorflow line in the setup.py to tensorflow; platform_machine == 'x86_64'. However, that project doesn't accept GitHub issues or contributions so I am creating an issue here.

Possible ByteVocabulary Bug

464   def _decode_tf(self, ids):
465     """Decode in TensorFlow.
466 
467     Args:
468       ids: a 1d tf.Tensor with dtype tf.int32
469     Returns:
470       a tf Scalar with dtype tf.string
471     """
472     return tf.py_function(func=self.decode, inp=[ids], Tout=tf.string)

The param 'ids' passed to _decode_tf above is a 1d tf.Tensor, and on line 472 it is wrapped into a list, but when self.decode is called with the param [ids], it throws the error in the list comprehension on line 100 (shown below):

File "/home/ubuntu/anaconda3/envs/google_t5/lib/python3.7/site-packages/seqio/vocabularies.py", line 100, in
decode
for i in clean_ids

File "/home/ubuntu/anaconda3/envs/google_t5/lib/python3.7/site-packages/seqio/vocabularies.py", line 100, in

for i in clean_ids

File "/home/ubuntu/anaconda3/envs/google_t5/lib/python3.7/site-packages/tensorflow/python/framework/ops.py",
line 1007, in bool
return bool(self._numpy())

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

 92   def decode(self, ids: Iterable[int]):
 93     """Detokenizes int32 iterable to a string, up through first EOS."""
 94     clean_ids = list(ids)
 95 
 96     if self.unk_id is not None:
 97       vocab_size = self._base_vocab_size
 98       clean_ids = [
 99           self.unk_id if i >= vocab_size else i
100           for i in clean_ids
101       ]
102 
103     if self.eos_id is not None and self.eos_id in clean_ids:
104       clean_ids = clean_ids[:clean_ids.index(self.eos_id) + 1]
105 
106     return self._decode(clean_ids)

How to apply the huggingface tokenizer in seqio.vocabulary

Hello.

I would like to use a Hugging Face tokenizer as a seqio vocabulary in t5x.

I inherited from seqio.Vocabulary and created my own BBPEVocabulary. However, the 'inputs' and 'targets' values cannot be accessed as text inside tf.data.Dataset.map, because the Hugging Face tokenizer expects a string but tf.data.Dataset provides a tf.Tensor like Tensor("args_0:0", shape=(), dtype=string).

Since the seqio SentencePiece module loads its model through the tf_text SentencePiece op (a .so file), I don't know how to handle this internally.

I would like to ask how to get and process a tf.Tensor as text in order to use the Hugging Face tokenizer in a tf.data.Dataset map.

I am attaching the code I used below.

Thank you:)

seqio/custom_task.py

from src.vocabularies import BBPEVocabulary
bbpe_vocab = BBPEVocabulary('custom_path')

seqio.TaskRegistry.add(
    "my_span_corruption_task",
    source=seqio.TFExampleDataSource(
        split_to_filepattern={"train": os.path.join('[MY_TF_RECORD_PATH]', "*train.tfrecord*")},
        feature_description={"text": tf.io.FixedLenFeature([], tf.string)}
    ),
    preprocessors=[
        functools.partial(
            preprocessors.rekey, key_map={
                "inputs": None,
                "targets": "text"
            }),
        preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,

    ],
    output_features=BBPE_OUTPUT_FEATURES,
    metric_fns=[])

seqio/preprocessors.py

def tokenize(dataset: tf.data.Dataset,
             output_features: OutputFeaturesType,
             copy_pretokenized: bool = True,
             with_eos: bool = False) -> tf.data.Dataset:
  tokenize_fn = functools.partial(
      tokenize_impl,
      output_features=output_features,
      copy_pretokenized=copy_pretokenized,
      with_eos=with_eos)
  return utils.map_over_dataset(fn=tokenize_fn)(dataset)

def tokenize_impl(features: Mapping[str, tf.Tensor],
                  output_features: OutputFeaturesType,
                  copy_pretokenized: bool = True,
                  with_eos: bool = False) -> Mapping[str, tf.Tensor]:
  ret = {}
  for k, v in features.items():
    if k in output_features:
      if copy_pretokenized:
        ret[f'{k}_pretokenized'] = v
      vocab = output_features[k].vocabulary
      v = vocab.encode_tf(v)  # At this line, `v` is a tf.Tensor and I can't obtain its text value
      ...[omitted]...

    ret[k] = v
  print(f'tokenize_impl | complete | return : {ret}')
  return ret

Using seqio for T5X Dataset Generation

Hi 🤗

I would like to pre-train a T5 Base model with T5X library.

If I understand the pre-training process correctly, I need TFRecords stored in a cloud bucket for that training (like it is done for BERT pre-training).

Now I have the following questions:

How is it possible to generate such a dataset from my own corpus? The corpus is a plain text file (each line = one sentence). I also have a T5-compatible vocab (SentencePiece model), because I don't want to use the existing T5 or mT5 vocabs.

Many thanks in advance!

Unimax sampler implementation?

I tried searching the code for SeqIO mixtures generated using the newly released UniMax sampler.

I am trying to pretrain a custom umUL2 model. If I could perhaps learn how to implement UniMax, it would be of great help. Thanks

Tokenizer is not behaving as expected on special tokens (doesn't recognize `pad` and `eos` tokens)

Looks like tokens like eos and pad do not get tokenized correctly:

Repro:

In [1]: import seqio

In [2]: vocab = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/cc_all.32000/sentencepiece.model')

In [3]: vocab.tokenizer.id_to_piece(0)
Out[3]: '<pad>'

In [4]: vocab.tokenizer.id_to_piece(1)
Out[4]: '</s>'

In [5]: vocab.encode(vocab.tokenizer.id_to_piece(1))
Out[5]: [3, 2, 87, 7, 3155]

In [6]: vocab.tokenizer.id_to_piece(vocab.tokenizer.encode(vocab.tokenizer.id_to_piece(1)))
Out[6]: ['▁', '<unk>', '/', 's', '>']

It breaks down the special tokens into wordpieces.

import seqio

Hello, why am I getting this warning just from importing seqio?

import seqio
2022-07-29 10:38:38.223245: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-07-29 10:38:38.223295: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

How to just use the mixture functionality in seqio

Hey there, I've been wanting to pretrain mT5 using the Hugging Face training script mentioned here: https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py

Sadly, the Hugging Face script doesn't support a mixture for pretraining mT5 in such a way that the model generalizes well on low-resource as well as high-resource languages.

Hence I've been wanting to use the mixture functionality of seqio, but upon using it I have to tokenize the data into the T5 SentencePiece vocabulary, and the seqio tasks do all the preprocessing.

The Hugging Face trainer takes care of the preprocessing, mapping the dataset to the tokenizer, etc.

My question is: is there a way I could use only the mixture functionality of seqio, without actually doing any preprocessing on the incoming datasets?

I was wondering if there is a way to feed in multiple datasets and get an output dataset (in text/str format) that is just an appropriate mixture of samples from all the datasets passed to the mixture function, which I could then use to pretrain with the HF trainer and do all the preprocessing there.
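
One rough sketch of that idea (my assumptions: placeholder task names and file patterns, text kept as raw strings via a PassThroughVocabulary, no tokenization preprocessors):

import seqio
import tensorflow as tf

# Keep features as raw strings: no tokenization, pass-through vocabulary.
TEXT_FEATURES = {
    'targets': seqio.Feature(
        vocabulary=seqio.PassThroughVocabulary(size=0),
        add_eos=False,
        dtype=tf.string,
        rank=0),
}

@seqio.utils.map_over_dataset
def line_to_dict(line):
  return {'targets': line}

for lang in ('hi', 'ta'):  # placeholder language codes
  seqio.TaskRegistry.add(
      f'{lang}_text_only',
      source=seqio.TextLineDataSource(
          split_to_filepattern={'train': f'/data/{lang}/train-*.txt'}),
      preprocessors=[line_to_dict],
      output_features=TEXT_FEATURES)

seqio.MixtureRegistry.add(
    'text_only_mix', ['hi_text_only', 'ta_text_only'], default_rate=1.0)

# Iterate the mixed dataset as plain strings and hand them to the HF pipeline.
ds = seqio.get_mixture_or_task('text_only_mix').get_dataset(
    sequence_length=None, split='train', shuffle=True, seed=42)
for ex in ds.as_numpy_iterator():
  text = ex['targets'].decode('utf-8')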

TfdsDataProvider gives error with non-None tfds_data_dir

SeqIO provides access to TFDS through TfdsDataProvider, which takes tfds_data_dir as an argument. However, it is not currently possible to use a non-None tfds_data_dir with TfdsDataProvider.

The issue can be traced to LazyTfdsLoader, which uses tfds.load with the hardcoded setting try_gcs=True. As noted in the TFDS docs, this is equivalent to setting data_dir='gs://tfds-data/datasets'. Consequently, TFDS raises an error when passing try_gcs=True and a non-None data_dir to tfds.load, as would occur when using non-None tfds_data_dir with TfdsDataProvider.

I believe allowing a non-None tfds_data_dir would be helpful in many scenarios. For example, many large datasets available through TFDS are hosted in locations other than gs://tfds-data/datasets, and in formats other than tfrecords. When downloading and preprocessing such datasets on preemptible VMs, it is desirable to specify a data_dir to allow one to save tfrecords to the cloud as detailed here. This allows users to avoid incurring the full download/processing delay on subsequent occasions: only some tfrecord shards need to be downloaded per host, and the downloads can overlap with model training. In this case, however, one must set try_gcs=False to avoid the TFDS error.

Rather than exposing the try_gcs option to the user, the implementation of LazyTfdsLoader can automatically set try_gcs=False when data_dir is not None. This way, it would be the user's responsibility to specify a tfds_data_dir or not when instantiating TfdsDataProvider, just like they are nominally able to do right now. The only downside is that if the data is available as tfrecords at gs://tfds-data/datasets and the user specifies their own data_dir, the try_gcs=False forces a potentially unnecessary download. However, a warning can be added to the docstring to mention this consequence of specifying tfds_data_dir.
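
A sketch of the proposed behavior (illustration only, not the current implementation):

import tensorflow_datasets as tfds

def load_with_optional_data_dir(name, split, data_dir=None):
  # Only fall back to the public TFDS GCS bucket when the caller did not
  # specify an explicit data_dir, avoiding the try_gcs/data_dir conflict.
  return tfds.load(
      name,
      split=split,
      data_dir=data_dir,
      try_gcs=data_dir is None)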

Can we make this happen? I can open a PR with this simple change!

Dataset performance

I am having a difficult time getting my data pipeline to the throughput levels that I would like before starting training with the t5x library.

Initially I planned to use a mixture of ~40 tasks (1-2 TB of text) for training and started doing some benchmarking, following general TPU and dataset performance tips and guides.

All of my datasets/tasks are JSON-lines files (output from earlier Dataflow jobs), varying from 200 to 1000 files each.

I used Colab notebooks or an E2 32-CPU instance during my benchmarking experiments, where I mounted my bucket holding all ~40 datasets that I plan to use. I sampled 16 different files as training files for each task source because it is recommended not to read too many files from GCS.

FileDataSource

I switched from FunctionDataSource to FileDataSource, mainly to use individual files during sharding without needing to read all the data, which I assume would be slower, especially for larger datasets.

import json

import seqio
import tensorflow as tf


@tf.autograph.experimental.do_not_convert
def read_file_fn(file):
  """Reads a JSON-lines file and yields the 'text' field of each line."""
  def _read_json(file):
    if isinstance(file, bytes):  # from_generator passes args as numpy bytes
      file = file.decode('utf-8')
    with open(file) as f:
      for line in f:
        yield json.loads(line)['text']

  return tf.data.Dataset.from_generator(
      _read_json, args=(file,),
      output_signature=tf.TensorSpec(shape=(), dtype=tf.string))


source = seqio.FileDataSource(
    read_file_fn=read_file_fn,
    split_to_filepattern=dict(train=train_files, validation=validation_files))

Here we can see the reading and deserialization performance of a single task source.

dataset = source.get_dataset("train", shard_info=seqio.ShardInfo(0,16))
tfds.benchmark(dataset, num_iter=10000)
Examples/sec (First included) 1622.67 ex/sec (total: 10001 ex, 6.16 sec)
Examples/sec (First only) 0.95 ex/sec (total: 1 ex, 1.05 sec)
Examples/sec (First excluded) 1954.66 ex/sec (total: 10000 ex, 5.12 sec)

Single Task

Then I register my seqio tasks with the full pipeline (including preprocessors) and test the performance of a single task.

dataset = seqio.get_mixture_or_task('task').get_dataset(
                    sequence_length={"inputs": 512, "targets": 512},
                    split="train",
                    shuffle=False,
                    num_epochs=1,
                    shard_info=seqio.ShardInfo(index=0, num_shards=16),
                    use_cached=False,
                    seed=42)
tfds.benchmark(dataset, num_iter=10000)
Examples/sec (First included) 485.21 ex/sec (total: 10001 ex, 20.61 sec)
Examples/sec (First only) 0.47 ex/sec (total: 1 ex, 2.11 sec)
Examples/sec (First excluded) 540.50 ex/sec (total: 10000 ex, 18.50 sec)

Mixture

When I benchmark the performance of the mixture it drops significantly (10x).

dataset = seqio.get_mixture_or_task("maana_version1.0_mixture").get_dataset(
                    sequence_length={"inputs": 512, "targets": 512},
                    split="train",
                    shuffle=False,
                    num_epochs=1,
                    shard_info=seqio.ShardInfo(index=0, num_shards=16),
                    use_cached=False,
                    seed=42)
tfds.benchmark(dataset, num_iter=10000)
Examples/sec (First included) 140.55 ex/sec (total: 10001 ex, 71.16 sec)
Examples/sec (First only) 0.09 ex/sec (total: 1 ex, 11.49 sec)
Examples/sec (First excluded) 167.60 ex/sec (total: 10000 ex, 59.67 sec)

Follow Up Thoughts

Please let me know if you have any feedback regarding the following comments and questions:

  1. In my experiments, reading from GCS vs. local files didn't differ much, so streaming directly from GCS is probably the better option (not having to download TBs of data), as long as the bucket is in the same zone as the TPU and the number of files is not too high. The docs state 10s to 100s of MB per file and 10s to 100s of files; in my case I have datasets with 200-1000 files (100 MB-1 GB range). Should I reduce the number of files, maybe by making each ~1 GB, and would this help pipeline performance?

  2. I also experimented with TFExampleDataSource vs. FileDataSource and didn't see any performance gain from TFExample compared to JSON. Is there an absolute best way to store data for seqio pipeline performance, e.g. would registering a TFDS dataset be better, as explained here? In my experience Dataflow jobs output a number of files equal to the number of workers, so it can be much higher than 100s. Is this OK, or should we keep the number of files in the 128-256 range?

  3. This is more of a T5X question but might still be related. My understanding is that when we get a dataset from a mixture, each task is iterated, and if shard info is specified that shard is returned as the data; later the same sample_fn is used for sampling from these task datasets with the given rates. I don't fully know how data parallelism plays together with model parallelism in t5x; maybe it depends on the model size and the number of TPU cores we have. Is it correct to assume each TPU core is a worker and data gets distributed to them when sharding? Would it then make sense to have the number of files be a multiple of the number of cores (e.g. 8x for v3-8, 32x for v3-32)? I also read that the batch is automatically distributed across TPU cores during computation, which I guess is why 8 x 128 is emphasized; does that mean we don't necessarily need to care about the number of files / sharding and can still use a single source file?

Notes from codelab:

The rule of thumb is to split your data across several (10s to 100s) large-ish files (10s to 100s of MB). If you have too many files, for example thousands of files, the time to access each file might start getting in the way. If you have too few files, like one or two, then you are not getting the benefits of streaming from multiple files in parallel.

How to decide ideal mixture rates?

What is the best way to decide which mixture ratio is optimal?

In the mT5 paper, an alpha value of 0.3 gave the best balance between performance on high- and low-resource languages.

However, I am pretraining mT5 on Indian languages with a diverse multilingual corpus, where Hindi has 60M+ samples and Kashmiri has around 100k samples.

So I wanted to know if I could hyperparameter-tune this somehow in t5x, or would just using alpha=0.3 work fine in my use case?
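
For reference, a sketch of how an alpha-style rate could be expressed when registering the mixture. This assumes seqio's mixing_rate_num_examples helper (check the exact name and signature in your version), where the exponent applied to per-task example counts is 1/temperature, so alpha=0.3 corresponds roughly to temperature ≈ 3.33; the task names below are placeholders:

import functools
import seqio

seqio.MixtureRegistry.add(
    'indic_mix_alpha_0.3',
    ['hindi_span_corruption', 'kashmiri_span_corruption'],  # ...remaining tasks
    default_rate=functools.partial(
        seqio.mixing_rate_num_examples, temperature=1.0 / 0.3),
)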

Using a registered task to add another

Suppose I have a task registered as follows:

seqio.TaskRegistry.add(
    "task_1",
    source=seqio.TfdsDataSource(tfds_name="c4/en:3.0.1", splits=["train", "validation"]),
    preprocessors=[
        preprocess1,
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocess2,
        preprocess3,
    ],
    output_features=...,
)

Is it possible to add another task that starts with the cached part of task_1, i.e., the part before seqio.CacheDatasetPlaceholder(), and only vary preprocess2 and preprocess3?

I'm looking for something like

seqio.TaskRegistry.add(
    "task_2",
    source=seqio.CachedTaskSource("task_1", ...),
    preprocessors=[
        preprocess2_modified,
        preprocess3_modified,
    ],
    ...
)

Concatenating Tasks?

Is there a way to concatenate multiple Tasks? Mixtures sample from component Tasks until one of them runs out of examples. Is there a variant that uses all of the examples from both Tasks in each epoch?
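
A workaround sketch (not a built-in seqio feature as far as I know; task names and lengths are placeholders): materialize each Task's dataset and concatenate them at the tf.data level, so every example from both appears exactly once per pass.

import seqio

lengths = {'inputs': 512, 'targets': 512}
ds_a = seqio.get_mixture_or_task('task_a').get_dataset(
    sequence_length=lengths, split='train', shuffle=False, num_epochs=1)
ds_b = seqio.get_mixture_or_task('task_b').get_dataset(
    sequence_length=lengths, split='train', shuffle=False, num_epochs=1)

# Requires compatible element structures; shuffle afterwards if order matters.
combined = ds_a.concatenate(ds_b).shuffle(10_000)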

`tokenize_and_append_eos` needs another required input (`output_features`)

tokenize_and_append_eos needs another required input (output_features). How can I use this function as a preprocessor, and how do I pass output_features?

This is the way I tried to use it:

preprocessors=[
          functools.partial(
              t5.data.preprocessors.parse_tsv,
              field_names=["input_text", "target_text"]),
          seqio.preprocessors.tokenize_and_append_eos,
    ],

HuggingFace Tokenizers compatibility

Hi, I have been trying to get SeqIO to work with Hugging Face tokenizers for a while, but have been running into trouble with non-T5-based tokenizers. Specifically, it seems that, because they are not SentencePiece tokenizers, tokenizers for models such as GPT-2 are incompatible with SeqIO's SentencePieceVocabulary, as they only have these vocab files:

{
  'vocab_file': 'vocab.json',
  'merges_file': 'merges.txt',
  'tokenizer_file': 'tokenizer.json'
}

Is there a currently supported way to use these tokenizers with SeqIO? Or would I need to make my own vocab class?
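
A rough sketch of the custom-vocab route (the method names follow seqio.Vocabulary's abstract interface as I understand it; verify against your installed version, and note that py_function keeps the Python tokenizer in the loop, so it is slow and not serializable):

import seqio
import tensorflow as tf
from transformers import GPT2TokenizerFast

class HFVocabulary(seqio.Vocabulary):
  """Wraps a Hugging Face tokenizer in a seqio Vocabulary (sketch only)."""

  def __init__(self, name='gpt2'):
    self._tok = GPT2TokenizerFast.from_pretrained(name)
    super().__init__()

  @property
  def eos_id(self):
    return self._tok.eos_token_id

  @property
  def unk_id(self):
    return self._tok.unk_token_id

  @property
  def _base_vocab_size(self):
    return self._tok.vocab_size

  def _encode(self, s):
    return self._tok.encode(s)

  def _decode(self, ids):
    return self._tok.decode(ids)

  def _encode_tf(self, s):
    # Runs the Python tokenizer at pipeline execution time via py_function.
    return tf.py_function(
        lambda t: self._tok.encode(t.numpy().decode('utf-8')), [s], tf.int32)

  def _decode_tf(self, ids):
    return tf.py_function(
        lambda t: self._tok.decode(list(t.numpy())), [ids], tf.string)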

seqio_cache_tasks fails on DataflowRunner

When trying to cache a dataset that does not fit in DirectRunner (e.g. google-research/text-to-text-transfer-transformer#323 (comment)) on Cloud Dataflow without any requirements.txt, like

python -m seqio.scripts.cache_tasks_main \
 --module_import="..." \
 --tasks="${TASK_NAME}" \
 --output_cache_dir="${BUCKET}/cache" \
 --alsologtostderr \
 --pipeline_options="--runner=DataflowRunner,--project=$PROJECT,--region=$REGION,--job_name=$TASK_NAME,--staging_location=$BUCKET/binaries,--temp_location=$BUCKET/tmp,--experiments=shuffle_mode=appliance"

it fails with ModuleNotFoundError: No module named 'seqio'.

If seqio is added with

echo seqio > /tmp/beam_requirements.txt

# and run the same, adding to `--pipeline_options`
--requirements_file=/tmp/beam_requirements.txt

it fails with

subprocess.CalledProcessError: Command '['.../.venv/bin/python', '-m', 'pip', 'download', '--dest', '..../pip-tmp/dataflow-requirements-cache', '-r', '/tmp/beam_requirements.txt', '--exists-action', 'i', '--no-binary', ':all:']' returned non-zero exit status 1.

 Pip install failed for package: -r
 Output from execution of subprocess: b"ERROR: Could not find a version that satisfies the requirement tensorflow-text (from versions: none)
ERROR: No matching distribution found for tensorflow-text"

This seems to be caused by seqio depending on tensorflow-text, which does not have any source release artifacts.

But the requirements cache in Apache Beam seems to be populated with --no-binary :all: before being made available to the workers.

Trying this in a clean venv gives the same result:

pip3 install  --no-binary :all: --no-deps tensorflow-text==2.6.0
ERROR: Could not find a version that satisfies the requirement tensorflow-text==2.6.0 (from versions: none)
ERROR: No matching distribution found for tensorflow-text==2.6.0

Am I doing something wrong, or how does everyone work around this? I would appreciate a hand here.
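
One workaround that might be worth trying (untested here; the package name and layout are placeholders): ship the job's dependencies through a --setup_file instead of --requirements_file, so the Dataflow workers install ordinary binary wheels instead of building everything from source with --no-binary :all:.

# setup.py for the Dataflow job (placeholder name/version).
import setuptools

setuptools.setup(
    name='seqio_cache_job',
    version='0.0.1',
    install_requires=['seqio'],  # tensorflow-text then arrives as a wheel
    packages=setuptools.find_packages(),
)

Then pass --setup_file=./setup.py in --pipeline_options instead of --requirements_file=/tmp/beam_requirements.txt.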

How to choose minimum sequence length while avoiding truncation

Hi,

I have a task that uses seqio.TfdsDataSource as its source and a preprocessing pipeline whose final steps look like this: [..., seqio.preprocessors.tokenize, seqio.CacheDatasetPlaceholder(), seqio.preprocessors.append_eos_after_trim].

I have cached this task, so I know the maximum token lengths for both inputs and targets.

My question is: when training a model with t5.models.mesh_transformer_main using this task and providing gin bindings for utils.run.sequence_length, should I use the values I see on the cached stats, or should I add +1 to account for the EOS token? My goal is to avoid data truncation by specifying smaller sequence lengths than what my data requires.

(P.S.: I know this is also related to the t5 repository, but I opened the issue here because I think my question is related to the seqio.preprocessors.append_eos_after_trim function. If you think it would be more appropriate to open this issue in another repository, please let me know, and I can change it.)

Thanks in advance,
Marcos

Dataset seeking for restarting from a T5X crashed run using HuggingFace datasets

Re-opening here as suggested by @adarob in google-research/t5x#421 (comment).

I wrote some hacky support for HuggingFace datasets using seqio.FunctionDataSource, specifically for pretraining and further pretraining models using T5X.

import functools

import seqio
import tensorflow as tf
from datasets import load_dataset


def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
    dataset = load_dataset(**dataset_params)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:  # TODO: add for...loop over num_epochs
        for item in dataset[str(split)]:
            yield item[column]

def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_params=dataset_params),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    )

dataset_name = 'NbAiLab/NCC'
dataset_params = {"path": dataset_name, "streaming": True}
dataset_shapes = {"train": 20830348, "validation": 473079}
source = seqio.FunctionDataSource(
    dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
    splits=("train", "validation"),
    caching_permitted=False,
    num_input_examples=dataset_shapes,
)

But unfortunately, as I face constant random crashes during training (google-research/t5x#366), I need a way to seek to the right dataset batch to properly continue training.

I see there's a continue_from_last_checkpoint variable in get_dataset(), but it seems it is not used for anything yet.

Is there a way to pass in the needed information to get_dataset_fn() so I can write the logic without using any hard-coded global variables?
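
As a stopgap sketch only (resume_step and global_batch_size are assumed to be known from the crashed run, and determinism depends on shuffling being disabled or re-seeded identically), the generator-backed dataset can be fast-forwarded with tf.data's skip:

examples_seen = resume_step * global_batch_size  # consumed by the crashed run
ds = source.get_dataset('train', shuffle=False)
ds = ds.skip(examples_seen)  # fast-forward past already-seen examples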

import seqio: AttributeError: module 'typing' has no attribute 'get_origin'

import seqio
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.7/dist-packages/seqio/__init__.py", line 18, in <module>
    from seqio.dataset_providers import *
  File "/usr/local/lib/python3.7/dist-packages/seqio/dataset_providers.py", line 38, in <module>
    import pyglove as pg
  File "/usr/local/lib/python3.7/dist-packages/pyglove/__init__.py", line 30, in <module>
    from pyglove.core import *
  File "/usr/local/lib/python3.7/dist-packages/pyglove/core/__init__.py", line 56, in <module>
    from pyglove.core import symbolic
  File "/usr/local/lib/python3.7/dist-packages/pyglove/core/symbolic/__init__.py", line 93, in <module>
    from pyglove.core.symbolic.diff import diff
  File "/usr/local/lib/python3.7/dist-packages/pyglove/core/symbolic/diff.py", line 153, in <module>
    (pg_typing.StrKey(), pg_typing.Object(Diff), 'Child node.')
  File "/usr/local/lib/python3.7/dist-packages/pyglove/core/typing/value_specs.py", line 1279, in __init__
    schema_or_field_list, allow_nonconst_keys=True)
  File "/usr/local/lib/python3.7/dist-packages/pyglove/core/typing/class_schema.py", line 1179, in create_schema
    value = ValueSpec.from_annotation(maybe_value_spec, True)
  File "/usr/local/lib/python3.7/dist-packages/pyglove/core/typing/value_specs.py", line 2131, in _from_annotation
    origin = typing.get_origin(annotation)
AttributeError: module 'typing' has no attribute 'get_origin'

Unable to train mT5 from t5x using mixtures: ValueError: Dataset is missing an expected feature during input_validation validation: 'inputs'

Hey there,

I am currently pretraining an mT5 model on 23 different languages, but when I create a mixture and set the mixture name in the t5x .gin config file for training on the mixture, I get the following error.

ValueError: Dataset is missing an expected feature during input_validation validation: 'inputs'

However, when I run the independent tasks individually by setting them in the gin file, everything works fine.

The following is what my task.py file looks like:

TaskRegistry = seqio.TaskRegistry
vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)


DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary, add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=vocabulary, add_eos=True)
}



def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_path=None):
    dataset = load_dataset(dataset_path, streaming=True, use_auth_token=True)
    if shuffle:
        if seed:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:
        for item in dataset[str(split)]:
            yield item[column]


def dataset_fn(split, shuffle_files, seed=None, dataset_path=None):
    return tf.data.Dataset.from_generator(
        functools.partial(gen_dataset, split, shuffle_files, seed, dataset_path=dataset_path),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_path)
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map"""
    return {**key_map, target_key: x}


TaskRegistry.add(
    "urdu_span_curruption",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset_path='StephennFernandes/ciil_mega_corpus_urdu'),
        splits=("train", "validation"),
        caching_permitted=False),
    preprocessors=[
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,
        # seqio.CacheDatasetPlaceholder(),
        preprocessors.span_corruption, 
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[]
)
 ### similar multiple languages are loaded here ### 


#seqio mixture 3.5 
seqio.MixtureRegistry.add(
  "ciil_mix_3.5",
  ["assamese_span_curruption", "bengali_span_curruption", 
  "bhisnupuriya_span_curruption", "bodo_span_curruption", 
  "divehi_span_curruption", "dogri_span_curruption", 
  "english_span_curruption", "gujarati_span_curruption",
  "hindi_span_curruption", "kannada_span_curruption", 
  "kashmiri_span_curruption", "konkani_span_curruption", 
  "maithili_span_curruption", "malayalam_span_curruption",
  "manipuri_span_curruption", "marathi_span_curruption",
  "nepali_span_curruption", "odia_span_curruption",
  "panjabi_span_curruption", "sanskrit_span_curruption",
  "tamil_span_curruption", "telugu_span_curruption",
   "urdu_span_curruption" ],
  default_rate=3.5
)

Upon running the mT5 model with the mixture name in the .gin file, I get the following error:

File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 744, in _main
    train_using_gin()
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 249, in train
    train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1366, in get_dataset
    return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
  File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1387, in get_dataset_inner
    ds = seqio.get_dataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1681, in get_dataset
    ds = feature_converter(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/feature_converters.py", line 404, in __call__
    ds = self._validate_dataset(
  File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/feature_converters.py", line 294, in _validate_dataset
    raise ValueError("Dataset is missing an expected feature during "
ValueError: Dataset is missing an expected feature during input_validation validation: 'inputs'

FunctionDataSource does not allow function with 3 positional arguments thus shuffling does not work

During creation, it checks that the function has only 2 positional arguments. For shuffling to be used, it should also accept a third argument, seed or seeds; otherwise an exception is thrown when trying to pass shuffle=True to get_dataset().

_validate_args(dataset_fn, ["split", "shuffle_files"])

Also, it only allows seed and not seeds later, but this never comes into effect since the whole thing fails during creation.

_validate_args(self._dataset_fn, ["split", "shuffle_files", "seed"])
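
A sketch of a signature that satisfies both checks (matching the pattern used elsewhere on this page): give seed a default value, so it is not counted as a required positional argument at construction time but is still present when shuffling supplies a seed.

import seqio
import tensorflow as tf

def dataset_fn(split, shuffle_files, seed=None):
  # `seed` defaults to None, so the creation-time check that expects only
  # ("split", "shuffle_files") as required args passes, while the later
  # validation that also expects "seed" finds the parameter present.
  del shuffle_files, seed
  return tf.data.Dataset.from_tensor_slices(tf.constant(['example text'], tf.string))

source = seqio.FunctionDataSource(dataset_fn=dataset_fn, splits=('train',))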

ValueError: mutable default <class 'seqio.vocabularies.PassThroughVocabulary'> for field vocabulary is not allowed: use default_factory

Traceback (most recent call last):
  File "/scenic/scenic/projects/vid2seq/vid2seq_test.py", line 13, in <module>
    from scenic.projects.vid2seq import trainer
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/scenic/projects/vid2seq/trainer.py", line 26, in <module>
    from scenic.projects.t5 import model as t5_model
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/scenic/projects/t5/model.py", line 29, in <module>
    from scenic.projects.t5 import layers
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/scenic/projects/t5/layers.py", line 9, in <module>
    from t5x import decoding
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/t5x/__init__.py", line 17, in <module>
    import t5x.adafactor
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/t5x/adafactor.py", line 64, in <module>
    from t5x import utils
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/t5x/utils.py", line 44, in <module>
    import seqio
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/seqio/__init__.py", line 18, in <module>
    from seqio.dataset_providers import *
  File "/anaconda3/envs/vid2seq/lib/python3.11/site-packages/seqio/dataset_providers.py", line 60, in <module>
    @dataclasses.dataclass(frozen=True)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/vid2seq/lib/python3.11/dataclasses.py", line 1213, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/vid2seq/lib/python3.11/dataclasses.py", line 958, in _process_class
    cls_fields.append(_get_field(cls, name, type, kw_only))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/vid2seq/lib/python3.11/dataclasses.py", line 815, in _get_field
    raise ValueError(f'mutable default {type(f.default)} for field '
ValueError: mutable default <class 'seqio.vocabularies.PassThroughVocabulary'> for field vocabulary is not allowed: use default_factory

Add method to directly add tasks/mixtures.

Currently MixtureRegistry and TaskRegistry have an add method that takes the arguments used to construct a Mixture / Task. This does not seem to play well if one wants to subclass Mixture or Task with a class that takes different arguments and add instances to the registries. A concrete example of a Task subclass: turning a Mixture back into a Task so that it looks "atomic" when one tries to add it back into a Mixture. Would it be possible to have a method that adds an object directly to the registries, without having to pass the arguments that the constructor uses?
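
A sketch of the kind of registration being asked for; the add_provider-style method shown here is an assumption about what the API could look like (or may already exist in newer versions), not a documented call:

import seqio

# Hypothetical usage: construct the provider yourself (e.g. a Task subclass
# that wraps a Mixture so it looks atomic) and hand the instance to the
# registry directly, instead of passing constructor arguments to add().
my_task = MyMixtureAsTask(...)  # placeholder subclass of seqio.Task
seqio.TaskRegistry.add_provider("my_atomic_mixture", my_task)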
