google / paxml

Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.

License: Apache License 2.0

Starlark 4.17% Python 86.98% Dockerfile 0.41% Shell 1.18% Jupyter Notebook 7.26%
c4 jax large-language-models llm model-flops parallelism gpt

paxml's Introduction

Paxml (aka Pax)

Pax is a framework to configure and run machine learning experiments on top of Jax.

Quickstart

Setting up a Cloud TPU VM

We refer to this page for more exhaustive documentation about starting a Cloud TPU project. The following command is sufficient to create a Cloud TPU VM with 8 cores from a corp machine.

export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-8
export TPU_NAME=paxml

# Create a TPU VM
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --version=$VERSION \
--project=$PROJECT \
--accelerator-type=$ACCELERATOR

If you are using TPU Pod slices, please refer to this guide. Run all the commands from a local machine using gcloud with the --worker=all option:

gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE \
--worker=all --command="<commands>"

The following quickstart sections assume you run on a single-host TPU, so you can ssh to the VM and run the commands there.

gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Installing Pax

After ssh-ing into the VM, you can install the paxml stable release from PyPI, or the dev version from GitHub.

To install the stable release from PyPI (https://pypi.org/project/paxml/):

python3 -m pip install -U pip
python3 -m pip install paxml "jax[tpu]" \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html

If you encounter issues with transitive dependencies and you are using the native Cloud TPU VM environment, please navigate to the corresponding release branch rX.Y.Z and download paxml/pip_package/requirements.txt. This file includes the exact versions of all transitive dependencies needed in the native Cloud TPU VM environment, in which we build/test the corresponding release.

git clone -b rX.Y.Z https://github.com/google/paxml
pip install --no-deps -r paxml/paxml/pip_package/requirements.txt

To install the dev version from GitHub in editable mode, for ease of editing code:

# install the dev version of praxis first
git clone https://github.com/google/praxis
pip install -e praxis
git clone https://github.com/google/paxml
pip install -e paxml
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Run a test model

# example model using pjit (SPMD)
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \
--job_log_dir=gs://<your-bucket>

# example model using pmap
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps \
--job_log_dir=gs://<your-bucket> \
--pmap_use_tensorstore=True

Documentation

Please visit our docs folder for documentation and Jupyter Notebook tutorials. See the following section for instructions on running Jupyter Notebooks on a Cloud TPU VM.

Run a notebook

You can run the example notebooks in the TPU VM in which you just installed paxml.

Steps to enable a notebook in a v4-8

  1. ssh into the TPU VM with port forwarding:

gcloud compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone=$ZONE --ssh-flag="-4 -L 8080:localhost:8080"

  2. Install Jupyter Notebook on the TPU VM and downgrade markupsafe:

pip install notebook
pip install markupsafe==2.0.1

  3. Export the Jupyter path:

export PATH=/home/$USER/.local/bin:$PATH

  4. scp the example notebooks from your local machine to your TPU VM:

gcloud compute tpus tpu-vm scp <local path of the notebooks> $TPU_NAME:<path inside TPU> --zone=$ZONE --project=$PROJECT

  5. Start Jupyter Notebook from the TPU VM and note the token it generates:

jupyter notebook --no-browser --port=8080

  6. In your local browser, go to http://localhost:8080/ and enter the token provided.

Note: In case you need to start using a second notebook while the first notebook is still occupying the TPUs, you can run pkill -9 python3 to free up the TPUs.

Run on GPU

Note: NVIDIA has released an updated version of Pax with H100 FP8 support and broad GPU performance improvements. Please visit the NVIDIA Rosetta repository for more details and usage instructions.

Run PGLE workflow on GPU

The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time of compute and collectives; the profile information is then fed back into the XLA compiler so that it can make better scheduling decisions.

The steps to use the Profile Guided Latency Estimator in XLA/GPU are:

    1. Run your workload once, with async collectives and latency hiding scheduler enabled.

You could do so by setting:

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true"

    2. Collect and post-process a profile using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.

import logging
import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler

# Define your profile directory
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)

# run your workflow
# for i in range(10):
#   train_step()

# Stop trace
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
# Sort by path so that the last entry is the most recent (timestamp-named) run.
directories = sorted((d for d in directories if d.is_dir()), key=os.fspath)
rundir = directories[-1]
logging.info('rundir: %s', rundir)

# Post process the profile
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

# Save the profile proto to a file.
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)

After this step, you will get a profile.pb file under the rundir printed in the code.

    3. Run the workload again, feeding that file into the compilation.

You need to pass the profile.pb file to the --xla_gpu_pgle_profile_file_or_directory_path flag.

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
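The two flag strings used in this workflow differ only in the PGLE path, so they can be assembled from one place. The following is a minimal sketch: `build_pgle_xla_flags` is a hypothetical helper (not part of JAX, XLA, or Pax), and only the flag names are taken from the steps above.

```python
import os

# Flags copied from the steps above.
BASE_FLAGS = (
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_async_all_gather=true",
    "--xla_gpu_enable_async_reduce_scatter=true",
)


def build_pgle_xla_flags(profile_path=None):
    """Returns an XLA_FLAGS string, adding the PGLE flag when a profile is given.

    Hypothetical helper for illustration only.
    """
    flags = list(BASE_FLAGS)
    if profile_path is not None:
        flags.append(
            "--xla_gpu_pgle_profile_file_or_directory_path="
            + os.fspath(profile_path)
        )
    return " ".join(flags)


# First run (profiling): no profile exists yet.
print(build_pgle_xla_flags())
# Second run: feed the collected profile back into compilation.
print(build_pgle_xla_flags("/path/to/profile/profile.pb"))
```

The resulting string would be exported as XLA_FLAGS before the workload starts; setting it before `import jax` is the safe choice.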

To enable logging in XLA and check whether the profile is being used, set the logging level to include INFO:

export TF_CPP_MIN_LOG_LEVEL=0

Then run the real workload. If the following lines appear in the log, the profile is being used by the latency hiding scheduler:

2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
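Checking a long log by eye is error-prone, so the check can be scripted. This is a sketch under the assumption that the two message texts match the sample lines above; the exact wording may differ between XLA versions, and `pgle_was_used` is a hypothetical helper name.

```python
# Substrings taken from the sample log lines above.
PGLE_MARKERS = (
    "Using PGLE profile from",
    "Found profile, using profile guided latency estimator",
)


def pgle_was_used(log_lines):
    """Returns True if every marker appears in at least one log line."""
    lines = list(log_lines)
    return all(any(marker in line for line in lines) for marker in PGLE_MARKERS)


sample_log = [
    "2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] "
    "Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb",
    "2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] "
    "Found profile, using profile guided latency estimator",
]
print(pgle_was_used(sample_log))
```

For a real run, the lines could come from `open(log_path)` on the captured stderr of the job.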

FAQs

  1. Pax runs on Jax. Details on running Jax jobs on Cloud TPU, and on a Cloud TPU pod, are available in the Cloud TPU documentation.

  2. If you run into dependency errors, please refer to the requirements.txt file in the branch corresponding to the stable release you are installing. For example, for the stable release 0.4.0, use branch r0.4.0 and refer to its requirements.txt for the exact versions of the dependencies used for that release.

Example Convergence Runs

Here are some sample convergence runs on c4 dataset.

1B model on c4 dataset

You can run a 1B params model on c4 dataset on TPU v4-8 using the config C4Spmd1BAdam4Replicas from c4.py as follows:

python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \
--job_log_dir=gs://<your-bucket>

[Figure: loss curve and log perplexity graph for the 1B run]

16B model on c4 dataset

You can run a 16B params model on c4 dataset on TPU v4-64 using the config C4Spmd16BAdam32Replicas from c4.py as follows:

python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd16BAdam32Replicas \
--job_log_dir=gs://<your-bucket>

[Figure: loss curve and log perplexity graph for the 16B run]

GPT3-XL model on c4 dataset

You can run the GPT3-XL model on c4 dataset on TPU v4-128 using the config C4SpmdPipelineGpt3SmallAdam64Replicas from c4.py as follows:

python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4SpmdPipelineGpt3SmallAdam64Replicas \
--job_log_dir=gs://<your-bucket>

[Figure: loss curve and log perplexity graph for the GPT3-XL run]

Benchmark on Cloud TPU v4

The PaLM paper introduced an efficiency metric called Model FLOPs Utilization (MFU). This is measured as the ratio of the observed throughput (in, for example, tokens per second for a language model) to the theoretical maximum throughput of a system harnessing 100% of peak FLOPs. It differs from other ways of measuring compute utilization because it doesn’t include FLOPs spent on activation rematerialization during the backward pass, meaning that efficiency as measured by MFU translates directly into end-to-end training speed.
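The ratio described above can be written out directly. The sketch below is illustrative only: the function names are hypothetical, the 6-FLOPs-per-parameter-per-token estimate is a common approximation that ignores attention FLOPs (an assumption, not something stated in this README), and the hardware and throughput numbers are made up.

```python
def approx_train_flops_per_token(num_params):
    """Assumption: about 6 FLOPs per parameter per token (forward + backward).

    Rematerialization FLOPs are deliberately not counted, as in the MFU definition.
    """
    return 6 * num_params


def model_flops_utilization(tokens_per_second, flops_per_token, peak_flops_per_second):
    """Observed throughput divided by the theoretical maximum throughput."""
    max_tokens_per_second = peak_flops_per_second / flops_per_token
    return tokens_per_second / max_tokens_per_second


# Hypothetical system: 64 chips, each with a peak of 2e14 FLOPs per second.
peak = 64 * 2e14
# Hypothetical workload: 1B-parameter model observed at 1M tokens per second.
mfu = model_flops_utilization(
    tokens_per_second=1_000_000,
    flops_per_token=approx_train_flops_per_token(1_000_000_000),
    peak_flops_per_second=peak,
)
print(f"MFU = {mfu:.4f}")
```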

To evaluate the MFU of a key class of workloads on TPU v4 Pods with Pax, we carried out an in-depth benchmark campaign on a series of decoder-only Transformer language model (GPT) configurations that range in size from billions to trillions of parameters on the c4 dataset. The following graph shows the training efficiency using the "weak scaling" pattern where we grew the model size in proportion to the number of chips used.

Pax on Multislice

The multislice configs in this repo refer to (1) the single-slice configs for syntax and model architecture, and (2) the MaxText repo for config values.

We provide example runs under c4_multislice.py as a starting point for Pax on multislice.

Setting up Cloud TPU VMs using Queued Resources

We refer to this page for more exhaustive documentation about using Queued Resources for a multi-slice Cloud TPU project. The following shows the steps needed to set up TPUs for running example configs in this repo.

export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-128 # or v4-384 depending on which config you run

Say, for running C4Spmd22BAdam2xv4_128 on 2 slices of v4-128, you'd need to set up TPUs the following way:

export TPU_PREFIX=<your-prefix> # New TPUs will be created based off this prefix
export QR_ID=$TPU_PREFIX
export NODE_COUNT=<number-of-slices> # 1, 2, or 4 depending on which config you run


# Create the TPU VMs
gcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=$ACCELERATOR --runtime-version=tpu-vm-v4-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX

Installing Pax

The setup commands described earlier need to be run on ALL workers in ALL slices. You can either 1) ssh into each worker of each slice individually, or 2) use a for loop with the --worker=all flag, as in the following command.

for ((i=0; i<$NODE_COUNT; i++))
do
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-$i --zone=us-central2-b --worker=all --command="pip install paxml && pip install orbax==0.1.1 && pip install \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
done

Run a test multislice model

To run the multislice configs, open as many terminals as your $NODE_COUNT. For our experiments on 2 slices (C4Spmd22BAdam2xv4_128), open two terminals. Then run each of the following commands from its own terminal.

From Terminal 0, run training command for slice 0 as follows:

export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-0 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS JAX_USE_PJRT_C_API_ON_TPU=1 \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"

From Terminal 1, concurrently run training command for slice 1 as follows:

export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-1 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS JAX_USE_PJRT_C_API_ON_TPU=1 \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"
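The per-slice commands above are identical apart from the slice index, so they can be generated rather than retyped. The sketch below only builds the argument lists and does not run anything. `slice_command` is a hypothetical helper; the prefix, bucket, and main.py path are placeholders; and LIBTPU_INIT_ARGS is shortened to a single flag from the command above for brevity.

```python
import shlex


def slice_command(tpu_prefix, slice_index, exp_name, job_log_dir, main_py,
                  libtpu_init_args, zone="us-central2-b"):
    """Builds the gcloud argv that starts training on one slice.

    Hypothetical helper: it assembles the command but does not execute it.
    """
    remote = " ".join([
        "LIBTPU_INIT_ARGS=" + shlex.quote(libtpu_init_args),
        "JAX_USE_PJRT_C_API_ON_TPU=1",
        "python3", shlex.quote(main_py),
        "--exp=tasks.lm.params.c4_multislice." + exp_name,
        "--job_log_dir=" + shlex.quote(job_log_dir),
    ])
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh",
        f"{tpu_prefix}-{slice_index}",
        f"--zone={zone}",
        "--worker=all",
        f"--command={remote}",
    ]


NODE_COUNT = 2  # two slices, as for C4Spmd22BAdam2xv4_128
commands = [
    slice_command(
        tpu_prefix="my-prefix",                          # placeholder
        slice_index=i,
        exp_name="C4Spmd22BAdam2xv4_128",
        job_log_dir="gs://my-bucket/logs",               # placeholder
        main_py="/path/to/site-packages/paxml/main.py",  # placeholder
        libtpu_init_args="--xla_tpu_enable_latency_hiding_scheduler=true",
    )
    for i in range(NODE_COUNT)
]
for argv in commands:
    print(shlex.join(argv))
```

Each argument list could then be handed to `subprocess.Popen` so that all slices start concurrently from one terminal; that part is left out here because it depends on your environment.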

MaxText to Pax

This table covers details on how the MaxText variable names have been translated to Pax.

Note that MaxText has a "scale" by which several parameters (base_num_decoder_layers, base_emb_dim, base_mlp_dim, base_num_heads) are multiplied to obtain their final values.

Also note that while Pax expresses the DCN and ICI MESH_SHAPE as arrays, MaxText has separate variables data_parallelism, fsdp_parallelism and tensor_parallelism for DCN and ICI. Since these values are set to 1 by default, only the variables with a value greater than 1 are recorded in this translation table.

That is, ICI_MESH_SHAPE = [ici_data_parallelism, ici_fsdp_parallelism, ici_tensor_parallelism] and DCN_MESH_SHAPE = [dcn_data_parallelism, dcn_fsdp_parallelism, dcn_tensor_parallelism]

| Pax C4Spmd22BAdam2xv4_128 | Value | MaxText 2xv4-128.sh | Value | Value after scale is applied |
| --- | --- | --- | --- | --- |
| | | scale (applied to next 4 variables) | 3 | |
| NUM_LAYERS | 48 | base_num_decoder_layers | 16 | 48 |
| MODEL_DIMS | 6144 | base_emb_dim | 2048 | 6144 |
| HIDDEN_DIMS | 24576 | MODEL_DIMS * 4 (= base_mlp_dim) | 8192 | 24576 |
| NUM_HEADS | 24 | base_num_heads | 8 | 24 |
| DIMS_PER_HEAD | 256 | head_dim | 256 | |
| PERCORE_BATCH_SIZE | 16 | per_device_batch_size | 16 | |
| MAX_SEQ_LEN | 1024 | max_target_length | 1024 | |
| VOCAB_SIZE | 32768 | vocab_size | 32768 | |
| FPROP_DTYPE | jnp.bfloat16 | dtype | bfloat16 | |
| USE_REPEATED_LAYER | TRUE | | | |
| SUMMARY_INTERVAL_STEPS | 10 | | | |
| ICI_MESH_SHAPE | [1, 64, 1] | ici_fsdp_parallelism | 64 | |
| DCN_MESH_SHAPE | [2, 1, 1] | dcn_data_parallelism | 2 | |
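The translation in this table can be reproduced with a few lines of arithmetic. This is a sketch using only the values listed above; `apply_scale` and `mesh_shape` are hypothetical helper names, not functions from Pax or MaxText.

```python
from math import prod

SCALE = 3
maxtext_base = {
    "base_num_decoder_layers": 16,
    "base_emb_dim": 2048,
    "base_mlp_dim": 8192,
    "base_num_heads": 8,
}


def apply_scale(base, scale):
    """Multiplies each MaxText base value by the scale factor."""
    return {name: value * scale for name, value in base.items()}


def mesh_shape(data=1, fsdp=1, tensor=1):
    """Orders the three parallelism values as [data, fsdp, tensor]."""
    return [data, fsdp, tensor]


scaled = apply_scale(maxtext_base, SCALE)
ici_mesh_shape = mesh_shape(fsdp=64)  # ici_fsdp_parallelism = 64
dcn_mesh_shape = mesh_shape(data=2)   # dcn_data_parallelism = 2

print(scaled)
print("ICI_MESH_SHAPE =", ici_mesh_shape)
print("DCN_MESH_SHAPE =", dcn_mesh_shape)
print("devices covered by the mesh:", prod(ici_mesh_shape) * prod(dcn_mesh_shape))
```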

Data inputs

Intro

Input is an instance of the BaseInput class for getting data into the model for train/eval/decode.

class BaseInput:

  def get_next(self):
    pass

  def reset(self):
    pass

It acts like an iterator: get_next() returns a NestedMap, where each field is a numerical array with batch size as its leading dimension.

Each input is configured by a subclass of BaseInput.HParams. In this page, we use p to denote an instance of BaseInput.HParams, and input to denote the object it instantiates.

Multihost infeed

In Pax, data is always multihost: Each Jax process will have a separate, independent input instantiated. Their params will have different p.infeed_host_index, set automatically by Pax.

Hence, the local batch size seen on each host is p.batch_size, and the global batch size is (p.batch_size * p.num_infeed_hosts). One will often see p.batch_size set to jax.local_device_count() * PERCORE_BATCH_SIZE.
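The batch-size arithmetic is small enough to spell out. In the sketch below the device count, per-core batch size, and host count are made-up numbers, and the two function names are hypothetical.

```python
def local_batch_size(local_device_count, percore_batch_size):
    """Batch size seen by one host: one per-core batch for each local device."""
    return local_device_count * percore_batch_size


def global_batch_size(batch_size, num_infeed_hosts):
    """Batch size across all hosts: p.batch_size * p.num_infeed_hosts."""
    return batch_size * num_infeed_hosts


# Hypothetical setup: 4 local devices per host, PERCORE_BATCH_SIZE = 16, 16 hosts.
per_host = local_batch_size(local_device_count=4, percore_batch_size=16)
total = global_batch_size(per_host, num_infeed_hosts=16)
print(per_host, total)
```

In a real job the first argument would come from `jax.local_device_count()`.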

Due to this multihost nature, input must be sharded properly.

For training, the inputs on different hosts must never emit identical batches, and for eval on a finite dataset, each input must terminate after the same number of batches. The best solution is to have the input implementation properly shard the data, such that the inputs on different hosts do not overlap. Failing that, one can also use different random seeds to avoid duplicate batches during training.
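One simple way to obtain non-overlapping inputs is to give each host every num_infeed_hosts-th example, offset by its host index. The following is a toy sketch on a Python list; `shard_for_host` is a hypothetical function, and real inputs would shard files or dataset records instead.

```python
def shard_for_host(examples, infeed_host_index, num_infeed_hosts):
    """Returns the examples owned by one host (strided sharding)."""
    return examples[infeed_host_index::num_infeed_hosts]


examples = list(range(10))
num_infeed_hosts = 4
shards = [
    shard_for_host(examples, host, num_infeed_hosts)
    for host in range(num_infeed_hosts)
]
for host, shard in enumerate(shards):
    print(host, shard)

# Every example is assigned to exactly one host.
seen = [x for shard in shards for x in shard]
assert len(seen) == len(set(seen)) == len(examples)
```

Note that the shards here have unequal sizes (3, 3, 2, 2), which is why eval data usually needs padding when it must terminate after the same number of batches on every host.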

Input for eval data

input.reset() is never called on training data, but it can be called on eval (or decode) data.

For each eval (or decode) run, Pax fetches N batches from input by calling input.get_next() N times. The number of batches used, N, can be a fixed number specified by user, via p.eval_loop_num_batches; or N can be dynamic (p.eval_loop_num_batches=None) i.e. we call input.get_next() until we exhaust all of its data (by raising StopIteration or tf.errors.OutOfRange).

If p.reset_for_eval=True, p.eval_loop_num_batches is ignored and N is determined dynamically as the number of batches to exhaust the data. In this case, p.repeat should be set to False, as doing otherwise would lead to infinite decode/eval.

If p.reset_for_eval=False, Pax will fetch p.eval_loop_num_batches batches. This should be set with p.repeat=True so that data are not prematurely exhausted.

Note that LingvoEvalAdaptor inputs require p.reset_for_eval=True.

| | N: static | N: dynamic |
| --- | --- | --- |
| p.reset_for_eval=True | Each eval run uses the first N batches. Not supported yet. | One epoch per eval run. eval_loop_num_batches is ignored. Input must be finite (p.repeat=False). |
| p.reset_for_eval=False | Each eval run uses non-overlapping N batches on a rolling basis, according to eval_loop_num_batches. Input must repeat indefinitely (p.repeat=True) or otherwise may raise an exception. | Not supported. |
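The two supported modes can be illustrated with a toy input. This is a sketch, not Pax code: `ToyInput` and `run_eval` are hypothetical, and the batches are plain dicts rather than NestedMaps.

```python
class ToyInput:
    """Toy stand-in for an input with get_next() and reset()."""

    def __init__(self, num_batches, repeat):
        self._num_batches = num_batches
        self._repeat = repeat
        self._pos = 0

    def get_next(self):
        if self._pos >= self._num_batches:
            if not self._repeat:
                raise StopIteration  # finite input is exhausted
            self._pos = 0            # repeating input wraps around
        batch = {"ids": [self._pos]}
        self._pos += 1
        return batch

    def reset(self):
        self._pos = 0


def run_eval(inp, eval_loop_num_batches=None, reset_for_eval=False):
    """Fetches the batches for one eval run."""
    if reset_for_eval:
        # Dynamic N: restart the input and read exactly one epoch.
        inp.reset()
        batches = []
        while True:
            try:
                batches.append(inp.get_next())
            except StopIteration:
                return batches
    # Static N: read a fixed number of batches on a rolling basis.
    return [inp.get_next() for _ in range(eval_loop_num_batches)]


finite = ToyInput(num_batches=3, repeat=False)
first_epoch = run_eval(finite, reset_for_eval=True)
second_epoch = run_eval(finite, reset_for_eval=True)

rolling = ToyInput(num_batches=3, repeat=True)
run_a = run_eval(rolling, eval_loop_num_batches=2)
run_b = run_eval(rolling, eval_loop_num_batches=2)
print(first_epoch, second_epoch, run_a, run_b)
```

With reset_for_eval=True each run sees the same full epoch, while with reset_for_eval=False consecutive runs continue where the previous one stopped.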

If running decode/eval on exactly one epoch (i.e. when p.reset_for_eval=True), the input must handle sharding correctly such that each shard raises at the same step, after exactly the same number of batches has been produced. This usually means that the input must pad the eval data. This is done automatically by SeqIOInput and LingvoEvalAdaptor (see more below).
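To see why padding is needed, consider 10 eval examples spread over 4 hosts with a batch size of 2. The sketch below pads every shard to the same number of full batches. It is only an illustration of the idea and makes no claim about how SeqIOInput or LingvoEvalAdaptor implement it; `pad_shard` is a hypothetical function.

```python
import math


def pad_shard(shard, num_examples_total, num_hosts, batch_size, pad_value=None):
    """Pads one host's shard so that all hosts produce the same number of batches."""
    largest_shard = math.ceil(num_examples_total / num_hosts)
    batches_per_host = math.ceil(largest_shard / batch_size)
    target_len = batches_per_host * batch_size
    return shard + [pad_value] * (target_len - len(shard))


examples = list(range(10))
num_hosts, batch_size = 4, 2
padded = [
    pad_shard(examples[host::num_hosts], len(examples), num_hosts, batch_size)
    for host in range(num_hosts)
]
for host, shard in enumerate(padded):
    print(host, shard, "->", len(shard) // batch_size, "batches")
```

A real implementation would also need to mark the padded entries (for example with zero weights) so that they do not contribute to the eval metrics.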

Eval metrics

For the majority of inputs, we only ever call get_next() on them to get batches of data. One type of eval data is an exception to this, where "how to compute metrics" is also defined on the input object.

This is only supported with SeqIOInput that defines some canonical eval benchmark. Specifically, Pax uses predict_metric_fns and score_metric_fns defined on the SeqIO task to compute eval metrics (although Pax does not depend on the SeqIO evaluator directly).

Best practices

When a model uses multiple inputs, either between train/eval or different training data between pretraining/finetuning, users must ensure that the tokenizers used by the inputs are identical, especially when importing different inputs implemented by others.

Users can sanity check the tokenizers by decoding some ids with input.ids_to_strings().

It's always a good idea to sanity check the data by looking at a few batches. Users can easily reproduce the param in a colab and inspect the data:

p = ... # specify the intended input param
inp = p.Instantiate()
b = inp.get_next()
print(b)

Training data typically should not use a fixed random seed. This is because if the training job is preempted, training data will start to repeat itself. In particular, for Lingvo inputs, we recommend setting p.input.file_random_seed = 0 for training data.

To test whether sharding is handled correctly, users can manually set different values for p.num_infeed_hosts and p.infeed_host_index and check whether the instantiated inputs emit different batches.

Input types

Pax supports 3 types of inputs: SeqIO, Lingvo, and custom.

SeqIO

SeqIOInput can be used to import datasets.

SeqIO inputs handle correct sharding and padding of eval data automatically.

Lingvo

LingvoInputAdaptor can be used to import datasets.

The input is fully delegated to the Lingvo implementation, which may or may not handle sharding automatically.

For GenericInput-based Lingvo input implementations using a fixed packing_factor, we recommend using LingvoInputAdaptorNewBatchSize to specify a bigger batch size for the inner Lingvo input and putting the desired (usually much smaller) batch size on p.batch_size.

For eval data, we recommend using LingvoEvalAdaptor to handle sharding and padding for running eval over one epoch.

Custom

Custom subclass of BaseInput. Users implement their own subclass, typically with tf.data or SeqIO.

Users can also inherit an existing input class to only customize post processing of batches. For example:

class MyInput(base_input.LingvoInputAdaptor):

  def get_next(self):
    batch = super().get_next()
    # modify batch: batch.new_field = ...
    return batch

Key Pax components

Hyperparameters

Hyperparameters are an important part of defining models and configuring experiments.

To integrate better with Python tooling, Pax/Praxis uses a pythonic dataclass based configuration style for hyperparameters.

class Linear(base_layer.BaseLayer):
  """Linear layer without bias."""

  class HParams(BaseHParams):
    """Associated hyperparams for this layer class.

    Attributes:
      input_dims: Depth of the input.
      output_dims: Depth of the output.
    """
    input_dims: int = 0
    output_dims: int = 0

Nesting

It's also possible to nest HParams dataclasses. In the example below, the linear_tpl attribute is a nested Linear.HParams.

class FeedForward(base_layer.BaseLayer):
  """Feedforward layer with activation."""

  class HParams(BaseHParams):
    """Associated hyperparams for this layer class.

    Attributes:
      input_dims: Depth of the input.
      output_dims: Depth of the output.
      has_bias: Adds bias weights or not.
      linear_tpl: Linear layer params.
      activation_tpl: Activation layer params.
    """
    input_dims: int = 0
    output_dims: int = 0
    has_bias: bool = True
    linear_tpl: BaseHParams = sub_config_field(Linear.HParams)
    activation_tpl: activations.BaseActivation.HParams = sub_config_field(
        ReLU.HParams)
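The nesting pattern can be tried out with plain standard-library dataclasses. The sketch below is an analogy, not Praxis code: `LinearHParams` and `FeedForwardHParams` are hypothetical stand-ins, and it assumes that sub_config_field plays a role similar to `dataclasses.field(default_factory=...)`, giving each parent config its own nested instance.

```python
import dataclasses


@dataclasses.dataclass
class LinearHParams:
    """Stand-in for Linear.HParams."""
    input_dims: int = 0
    output_dims: int = 0


@dataclasses.dataclass
class FeedForwardHParams:
    """Stand-in for FeedForward.HParams with a nested config."""
    input_dims: int = 0
    output_dims: int = 0
    has_bias: bool = True
    # default_factory creates a fresh nested config per parent instance.
    linear_tpl: LinearHParams = dataclasses.field(default_factory=LinearHParams)


a = FeedForwardHParams(input_dims=4, output_dims=8)
b = FeedForwardHParams()

# Configure the nested template of `a` only.
a.linear_tpl.input_dims = a.input_dims
a.linear_tpl.output_dims = a.output_dims

print(a)
print(b)  # b.linear_tpl is unaffected by the change to a.linear_tpl
```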

Layers

A Layer represents an arbitrary function possibly with trainable parameters. A Layer can contain other Layers as children. Layers are the essential building blocks of models. Layers inherit from the Flax nn.Module.

Typically layers define two methods:

setup

This method creates trainable weights and child layers.

fprop

This method defines the forward propagation function, computing some output based on the inputs. Additionally, fprop might add summaries or track auxiliary losses.

Fiddle and Shared layers

Fiddle is an open-sourced Python-first configuration library designed for ML applications. Pax/Praxis supports interoperability with Fiddle Config/Partial(s) and some advanced features like eager error checking and shared parameters.

fdl_config = Linear.HParams.config(input_dims=1, output_dims=1)

# A typo.
fdl_config.input_dimz = 31337  # Raises an exception immediately to catch typos fast!


fdl_partial = Linear.HParams.partial(input_dims=1)

Using Fiddle, layers can be configured to be shared (e.g., instantiated only once with shared trainable weights).

Model

A model defines solely the network, typically as a collection of Layers, and defines interfaces for interacting with the model, such as decoding.

Some example base models include:

  • LanguageModel
  • SequenceModel
  • ClassificationModel

Task

A Task contains one or more Models and Learners/Optimizers. The simplest Task subclass is a SingleTask, which requires the following HParams:

  class HParams(base_task.BaseTask.HParams):
    """Task parameters.

    Attributes:
      name: Name of this task object, must be a valid identifier.
      model: The underlying JAX model encapsulating all the layers.
      train: HParams to control how this task should be trained.
      metrics: A BaseMetrics aggregator class to determine how metrics are
         computed.
      loss_aggregator: A LossAggregator aggregator class to determine how the
        losses are aggregated (e.g. single or MultiLoss)
      vn: HParams to control variational noise.
    """

Releases

| PyPI Version | Commit |
| --- | --- |
| 0.1.0 | 546370f5323ef8b27d38ddc32445d7d3d1e4da9a |

Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

paxml's People

Contributors

a9isha, aaroey, aaudiber, ashishenoyp, ashors1, bignamehyp, cpgaffney1, daiyip, dhr, edloper, hawkinsp, ishark, jysohn23, kaixih, laurentes, panzhufeng, pluskid, ppwwyyxx, protoget, pschuh, rchen152, rhofour, saeta, schien1729, sgpyc, ukoxyz, wangpengmit, yashk2810, zhangqiaorjc, zhangyujing

paxml's Issues

Installing paxml from source failed due to dependency problem

This is using the latest development version (see full log here.):

#8 79.88 + git clone https://github.com/google/paxml.git /opt/paxml
#8 79.89 Cloning into '/opt/paxml'...
#8 80.87 + pushd /opt/paxml
#8 80.87 + git checkout HEAD
#8 80.87 /opt/paxml /
#8 80.89 Your branch is up to date with 'origin/main'.
#8 80.89 + pip install -e '.[gpu]'
...
#8 94.75 ERROR: Cannot install paxml and paxml[gpu]==1.0.0 because these package versions have conflicting dependencies.
#8 94.75 
#8 94.75 The conflict is caused by:
#8 94.75     paxml[gpu] 1.0.0 depends on tensorflow~=2.9.2
#8 94.75     tensorflow-text 2.9.0 depends on tensorflow<2.10 and >=2.9.0; platform_machine != "arm64" or platform_system != "Darwin"
#8 94.75     lingvo 0.12.1 depends on tensorflow==2.9
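The conflict reported by pip can be reproduced by checking candidate tensorflow versions against the three constraints in the log. The sketch below encodes the specifiers by hand, following the standard Python version-specifier rules (`~=2.9.2` means `>=2.9.2` within the 2.9 series, and `==2.9` matches only 2.9 / 2.9.0); the candidate list is chosen for illustration, and the helper names are hypothetical.

```python
def parse(version):
    """Turns '2.9.2' into (2, 9, 2)."""
    return tuple(int(part) for part in version.split("."))


def pad(version_tuple, length=3):
    """Zero-pads a version tuple so that (2, 9) compares equal to (2, 9, 0)."""
    return version_tuple + (0,) * (length - len(version_tuple))


def paxml_ok(v):    # paxml[gpu]: tensorflow~=2.9.2
    return pad(v) >= (2, 9, 2) and v[:2] == (2, 9)


def tf_text_ok(v):  # tensorflow-text 2.9.0: tensorflow>=2.9.0,<2.10
    return (2, 9, 0) <= pad(v) < (2, 10, 0)


def lingvo_ok(v):   # lingvo 0.12.1: tensorflow==2.9
    return pad(v) == (2, 9, 0)


candidates = ["2.9.0", "2.9.1", "2.9.2", "2.9.3", "2.10.0"]
compatible = [
    c for c in candidates
    if paxml_ok(parse(c)) and tf_text_ok(parse(c)) and lingvo_ok(parse(c))
]
print(compatible)
```

The list comes out empty because paxml[gpu] needs at least 2.9.2 while lingvo pins exactly 2.9, so no single tensorflow version satisfies both.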

Use bfloat16 for eval

I'm running paxml on an Intel Xeon CPU server using the paxml/main.py program. I'm trying to create a model that creates weights in bfloat16, and uses that datatype during eval. I modified the LmCloudSpmd2B configuration with the following lines:

MODEL_DTYPE = jnp.bfloat16
ICI_MESH_SHAPE = [1, 1, 1]

The training status output includes the following output.

model.dtype : type/jax.numpy/float32
model.fprop_dtype : dtype[bfloat16]

All of the other operator datatypes are float32. When I run that model with the --eval switch all of the computation is in float32. How can I direct paxml to use bfloat16?

Tom

Pipeline Parallelism: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size Fatal Python error: Aborted

Hello!

I am trying to implement 126 million parameter GPT-3 with Pipeline Parallelism on PAXML. I run into some errors when NUM_MICROBATCHES > 1.

System:

8X NVIDIA A100-SXM 80 GB

Gin Configs:

from __gin__ import dynamic_registration

import __main__ as train_script
from paxml import gin_utils
from paxml.tasks.lm import model_params_with_gin
from paxml.tasks.lm.params import datasets_gin
from praxis import optimizers
from praxis import schedules
from praxis.layers import activations
from praxis.layers import repeats
from jax import numpy as jnp

MAX_SL=2048
SUMMARY_INTERVAL_STEPS=100
CHECKPOINT_EVERY_N_STEPS=1000
EVAL_INTERVAL_STEPS=100
MAX_STEPS=600000
NUM_STAGES = 4
ICI_MESH_SHAPE=[%NUM_STAGES, 1, 1, 2]
PERCORE_BATCH_SIZE = 2

MODEL = @model_params_with_gin.TransformerLmSpmdPipeline()
model_params_with_gin.TransformerLmSpmdPipeline:
  USE_REPEATED_LAYER = False
  MAX_SEQ_LEN = %MAX_SL
  NUM_LAYERS = 12
  NUM_HEADS = 12
  MODEL_DIMS = 768
  HIDDEN_DIMS = 3072
  DIMS_PER_HEAD = 64
  VOCAB_SIZE = 51200
  TRAINABLE_POSITION_EMB = True
  TRAINABLE_PE_MAX_SEQ_LEN = %MAX_SL
  ACTIVATION_CLS = @activations.GELU.HParams()
  PACKED_INPUT = True
  USE_BIAS = False
  MAX_STEPS=%MAX_STEPS
  INIT_STD = 0.023
  EVAL_INTERVAL_STEPS = 100
  NUM_STAGES = %NUM_STAGES
  NUM_MICROBATCHES = 2
  ICI_MESH_SHAPE = %ICI_MESH_SHAPE
  FPROP_DTYPE = @jnp.bfloat16
  SUMMARY_INTERVAL_STEPS=%SUMMARY_INTERVAL_STEPS
  CHECKPOINT_EVERY_N_STEPS=%CHECKPOINT_EVERY_N_STEPS
  EVAL_INTERVAL_STEPS=%EVAL_INTERVAL_STEPS

OPTIMIZER = @optimizers.Adam.HParams()
optimizers.Adam.HParams:
  beta1 = 0.9
  beta2 = 0.95
  learning_rate = 6e-4
  epsilon_root = 0.0
  epsilon = 1e-8
  weight_decay = 0.1
  clip_threshold = 1.0
  clip_gradient_norm_to_value = 5.0


SCHEDULER = @schedules.LinearRampupCosineDecay.HParams()
schedules.LinearRampupCosineDecay.HParams:
  warmup_steps = 636
  decay_start = 637
  decay_end = 500000
  min_ratio = 0.1
  max = 1.0

DATASET = @datasets_gin.PileUnsupervisedDataset()
datasets_gin.PileUnsupervisedDataset:
  MAX_SEQ_LEN = %MAX_SL
  PERCORE_BATCH_SIZE = %PERCORE_BATCH_SIZE

## experiment == model + dataset
EXPERIMENT = @model_params_with_gin.Experiment()
model_params_with_gin.Experiment:
  model = %MODEL
  dataset = %DATASET
  optimizer = %OPTIMIZER
  scheduler = %SCHEDULER
  
train_script.run:
  experiment_config = %EXPERIMENT

Command:

#! /bin/bash

set -x

PYTHONPATH=/pax/paxml:/pax/praxis python3 /pax/paxml/paxml/main.py \
    --exp=tasks.lm.params.c4.PileSpmdAdam \
    --gin_file="/pax/paxml/configs/gpt3_126_pp.gin" \
    --tfds_data_dir="/pax/datasets" \
    --vocab_path='/pax/vocab/c4_en_301_5Mexp2_spm.model' \
    --pmap_use_tensorstore=True \
    --job_log_dir=/logs/ \
    --alsologtostderr 

set +x

XLA Compile Time Error:

2022-10-10 16:01:05.537760: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size 
Fatal Python error: Aborted

Current thread 0x00007f5c10b73740 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 940 in backend_compile
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294 in wrapper
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 996 in compile_or_get_cached
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 3048 in from_hlo
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2890 in compile
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 815 in _pjit_call_impl
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 685 in process_primitive
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 327 in bind_with_trace
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 324 in bind
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 385 in wrapped
  File "/pax/paxml/paxml/train.py", line 1087 in train_and_evaluate_spmd_model
  File "/pax/paxml/paxml/train.py", line 271 in train_and_evaluate
  File "/pax/paxml/paxml/main.py", line 290 in run_experiment
  File "/pax/paxml/paxml/main.py", line 535 in run
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582 in gin_wrapper
  File "/pax/paxml/paxml/main.py", line 588 in main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251 in _run_main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303 in run
  File "/pax/paxml/paxml/main.py", line 631 in <module>

There is no problem when NUM_MICROBATCHES = 1.

It would be great if someone could look into this to figure out what may be causing XLA to break when using NUM_MICROBATCHES > 1.

Pipeline Parallelism: USE_REPEATED_LAYERS bug

Hello!

I am trying to implement a 126-million-parameter GPT-3 with pipeline parallelism in PAXML. I notice that USE_REPEATED_LAYERS=True speeds up compilation and also reduces the memory requirement. However, when I set USE_REPEATED_LAYERS=True together with pipeline parallelism, I get the following error.

System:

8X NVIDIA A100-SXM 80 GB

Gin Configs:

from __gin__ import dynamic_registration

import __main__ as train_script
from paxml import gin_utils
from paxml.tasks.lm import model_params_with_gin
from paxml.tasks.lm.params import datasets_gin
from praxis import optimizers
from praxis import schedules
from praxis.layers import activations
from praxis.layers import repeats
from jax import numpy as jnp

MAX_SL=2048
SUMMARY_INTERVAL_STEPS=100
CHECKPOINT_EVERY_N_STEPS=1000
EVAL_INTERVAL_STEPS=100
MAX_STEPS=600000
NUM_STAGES = 4
ICI_MESH_SHAPE=[%NUM_STAGES, 1, 1, 2]
PERCORE_BATCH_SIZE = 2

MODEL = @model_params_with_gin.TransformerLmSpmdPipeline()
model_params_with_gin.TransformerLmSpmdPipeline:
  USE_REPEATED_LAYER = True
  MAX_SEQ_LEN = %MAX_SL
  NUM_LAYERS = 12
  NUM_HEADS = 12
  MODEL_DIMS = 768
  HIDDEN_DIMS = 3072
  DIMS_PER_HEAD = 64
  VOCAB_SIZE = 51200
  TRAINABLE_POSITION_EMB = True
  TRAINABLE_PE_MAX_SEQ_LEN = %MAX_SL
  ACTIVATION_CLS = @activations.GELU.HParams()
  PACKED_INPUT = True
  USE_BIAS = False
  MAX_STEPS=%MAX_STEPS
  INIT_STD = 0.023
  EVAL_INTERVAL_STEPS = 100
  NUM_STAGES = %NUM_STAGES
  NUM_MICROBATCHES = 1
  ICI_MESH_SHAPE = %ICI_MESH_SHAPE
  FPROP_DTYPE = @jnp.bfloat16
  SUMMARY_INTERVAL_STEPS=%SUMMARY_INTERVAL_STEPS
  CHECKPOINT_EVERY_N_STEPS=%CHECKPOINT_EVERY_N_STEPS
  EVAL_INTERVAL_STEPS=%EVAL_INTERVAL_STEPS

OPTIMIZER = @optimizers.Adam.HParams()
optimizers.Adam.HParams:
  beta1 = 0.9
  beta2 = 0.95
  learning_rate = 6e-4
  epsilon_root = 0.0
  epsilon = 1e-8
  weight_decay = 0.1
  clip_threshold = 1.0
  clip_gradient_norm_to_value = 5.0


SCHEDULER = @schedules.LinearRampupCosineDecay.HParams()
schedules.LinearRampupCosineDecay.HParams:
  warmup_steps = 636
  decay_start = 637
  decay_end = 500000
  min_ratio = 0.1
  max = 1.0

DATASET = @datasets_gin.PileUnsupervisedDataset()
datasets_gin.PileUnsupervisedDataset:
  MAX_SEQ_LEN = %MAX_SL
  PERCORE_BATCH_SIZE = %PERCORE_BATCH_SIZE

## experiment == model + dataset
EXPERIMENT = @model_params_with_gin.Experiment()
model_params_with_gin.Experiment:
  model = %MODEL
  dataset = %DATASET
  optimizer = %OPTIMIZER
  scheduler = %SCHEDULER
  
train_script.run:
  experiment_config = %EXPERIMENT

Command:

#! /bin/bash

set -x

PYTHONPATH=/pax/paxml:/pax/praxis python3 /pax/paxml/paxml/main.py \
    --exp=tasks.lm.params.c4.PileSpmdAdam \
    --gin_file="/pax/paxml/configs/gpt3_126_pp.gin" \
    --tfds_data_dir="/pax/datasets" \
    --vocab_path='/pax/vocab/c4_en_301_5Mexp2_spm.model' \
    --pmap_use_tensorstore=True \
    --job_log_dir=/logs/ \
    --alsologtostderr 

set +x

Error:

Traceback (most recent call last):
  File "/pax/paxml/paxml/main.py", line 631, in <module>
    app.run(main, flags_parser=_gin_flags_parser)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/pax/paxml/paxml/main.py", line 588, in main
    run_with_gin()
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/local/lib/python3.8/dist-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/pax/paxml/paxml/main.py", line 535, in run
    run_experiment(
  File "/pax/paxml/paxml/main.py", line 290, in run_experiment
    train.train_and_evaluate(
  File "/pax/paxml/paxml/train.py", line 271, in train_and_evaluate
    train_and_evaluate_spmd_model(task_p, train_input_p, job_log_dir,
  File "/pax/paxml/paxml/train.py", line 851, in train_and_evaluate_spmd_model
    vars_weight_params = jax_task.model.abstract_init_with_metadata(
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/transforms.py", line 1320, in wrapped_fn
    return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 353, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 652, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/pax/praxis/praxis/base_layer.py", line 1231, in abstract_init_with_metadata
    variables_abstract = jax.eval_shape(init_fn, rngs)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 3024, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 662, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1929, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1946, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/layers/pipeline.py", line 217, in force_init
    body_init_fn(self.body, None)
  File "/pax/praxis/praxis/layers/pipeline.py", line 162, in fn
    model.force_init(None)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/pax/praxis/praxis/base_layer.py", line 1169, in force_init
    jax.tree_map(force, val)
  File "/pax/praxis/praxis/base_layer.py", line 1167, in force
    v.force_init(*args)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
TypeError: force_init() takes 1 positional argument but 2 were given
  In call to configurable 'run' (<function run at 0x7fedd131ab80>)

Would you have any suggestions on how to fix this?
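
One possible reading of the traceback above, not confirmed against the praxis source: the last frames show a parent forwarding its positional arguments with `v.force_init(*args)` after `pipeline.py` calls `model.force_init(None)`, so some child layer's `force_init` override may accept only `self`. The sketch below reproduces that kind of failure in plain Python. The class names `ForwardingParent` and `NarrowChild` are hypothetical and are not praxis classes.

```python
class ForwardingParent:
    """Hypothetical stand-in for a layer that forwards init arguments to children."""

    def __init__(self, children):
        self.children = children

    def force_init(self, *args):
        # Mirrors the `v.force_init(*args)` pattern visible in the traceback.
        for child in self.children:
            child.force_init(*args)


class NarrowChild:
    """Hypothetical child whose override accepts no extra positional argument."""

    def force_init(self):
        pass


def reproduce():
    """Call force_init(None) on the parent and return the resulting error text."""
    try:
        ForwardingParent([NarrowChild()]).force_init(None)
    except TypeError as err:
        return str(err)
    return None


message = reproduce()
print(message)
```

If this reading is right, the mismatch would sit in whichever layer is wrapped when USE_REPEATED_LAYER is enabled, but that remains an assumption until checked against the praxis code.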

Int8 checkpoint

Hi, is there a way to save a quantized int8 checkpoint? Looks like right now the checkpoint is in fp32.

[Question] Very low MFU (30%~35%) when training bf16 Llama2 and GPT models on a single SXM4 A100 machine

I don't know what is going wrong. Are the compute precision and parameter precision not set correctly? DeepSpeed or Megatron can easily achieve 55% MFU on the same machine.
Here is my bash script:

#! /bin/bash
set -u
set -o pipefail

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}      # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}

export VOCAB_PATH=$VOCAB_PATH

BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
                       --xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_async_all_gather=true
                       --xla_gpu_enable_async_reduce_scatter=true  --xla_gpu_enable_highest_priority_async_stream=true
                       --xla_gpu_enable_triton_softmax_fusion=false  --xla_gpu_all_reduce_combine_threshold_bytes=51200
                       --xla_gpu_graph_level=3 --xla_gpu_enable_async_all_reduce=true
                       --xla_gpu_enable_async_collectives=true --xla_gpu_enable_async_collective_permute=true
                       --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
                       --xla_gpu_enable_async_all_to_all=true --xla_gpu_all_reduce_contiguous=true
                       --xla_gpu_all_reduce_blueconnect_num_devices_per_host=true
                       --xla_gpu_enable_cudnn_frontend=true --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true
                       --xla_gpu_enable_cudnn_layer_norm "}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"

export ENABLE_TE=1

mkdir -p ${LOG_DIR}
python3 -u -m paxml.main \
    --job_log_dir=${LOG_DIR} \
    --fdl_config=paxml.tasks.lm.params.nvidia.Llama2_7B \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[1,$(expr ${NUM_GPUS}), 1]" \
    --fdl.DCN_MESH_SHAPE="[1,1,1]" \
    --fdl.NUM_STAGES=1 \
    --fdl.MICROBATCH_SIZE=$PERCORE_BATCH_SIZE \
    --fdl.PERCORE_BATCH_SIZE=$PERCORE_BATCH_SIZE \
    --tfds_data_dir=$TFDS_DATA_DIR \
    --alsologtostderr \
    2>&1 | tee ${LOG_DIR}/llama2_7B_output.log

EXP_STATUS=$?

if [ $EXP_STATUS != 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi

According to https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax, NVIDIA trained a 5B GPT model with native BF16 on 256 A100 GPUs. The reported throughput is 465.45 sequences/sec at a global batch size of 8*256=2048 sequences, which means each step took about 4.4 s. Am I correct?
The following script calculates the MFU for that run as 38.958427%, which is too low!

# Nvidia Jax GPT5B
card_num=256
gbs=8*card_num
layers=24
num_query=32
num_heads=32
enc_seq_len=2048
hs=4096
ffn_hs=16384
vocab=50304

sequences_per_sec=465.45
seconds_per_step=gbs/sequences_per_sec


#Model total parameters:
params_qkv_state = (1+2*(num_query/num_heads))*hs*hs
params_post_attention_linear = hs*hs
params_fead_forward_network = 2*hs*ffn_hs
params_vocabulary_embedding = hs*vocab


#FPROP:
qkv_state = gbs*2*(1+2*(num_query/num_heads))*enc_seq_len*hs*hs
attention_matrix_computation = gbs*2*enc_seq_len*enc_seq_len*hs
attention_over_values = gbs*2*enc_seq_len*enc_seq_len*hs
post_attention_linear_projection = gbs*2*enc_seq_len*hs*hs
fead_forward_network = gbs*(2*2*enc_seq_len*ffn_hs*hs)
vocabulary_embedding = gbs*2*enc_seq_len*hs*vocab

#BPROP:
#FPROP*2

model_params = (params_qkv_state+params_post_attention_linear+params_fead_forward_network)*layers + params_vocabulary_embedding 
model_float = 3*((qkv_state+attention_matrix_computation+attention_over_values+post_attention_linear_projection+fead_forward_network)*layers + vocabulary_embedding) 
model_flops = model_float/seconds_per_step
cluster_ideal_flops = 312*(10**12) * card_num
MFU = model_flops/cluster_ideal_flops
print("Model parameters {:4f}B MFU={:4f}%".format(model_params/(10**9),MFU*100))
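
As a rough cross-check of the script above, the sketch below uses the common rule of thumb of about 6 FLOPs per parameter per token for forward plus backward, which ignores the attention-matrix terms. The helper name `approx_mfu` is hypothetical; the model dimensions, throughput, and the 312 TFLOP/s peak are the values from the script.

```python
def approx_mfu(n_params, tokens_per_sec, peak_flops_per_device, num_devices):
    """Estimate MFU with the ~6 FLOPs per parameter per token rule of thumb.

    Attention-matrix FLOPs are ignored, so this is a lower-bound style estimate.
    """
    achieved_flops_per_sec = 6.0 * n_params * tokens_per_sec
    return achieved_flops_per_sec / (peak_flops_per_device * num_devices)


# Values taken from the script above.
layers, hs, ffn_hs, vocab = 24, 4096, 16384, 50304
seq_len, card_num = 2048, 256
sequences_per_sec = 465.45

# QKV (3*hs*hs) + output projection (hs*hs) + feed-forward (2*hs*ffn_hs) per layer,
# plus the vocabulary embedding.
n_params = layers * (4 * hs * hs + 2 * hs * ffn_hs) + hs * vocab
mfu = approx_mfu(n_params, sequences_per_sec * seq_len, 312e12, card_num)
print(f"params={n_params / 1e9:.2f}B  approx MFU={mfu * 100:.1f}%")
```

Because the attention terms are left out, this estimate should land somewhat below the figure from the full script. If the two differ by only a few points, the per-term accounting is probably not the source of the low number.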

Error running Common Crawl example

Sorry to interrupt! When running

python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \
--job_log_dir=gs://<your-bucket> 

from the examples, I encountered the following error, which seems to suggest that I cannot read from the bucket configured in c4.py:

Traceback (most recent call last):
  File ".local/lib/python3.8/site-packages/paxml/main.py", line 407, in <module>
    app.run(main, flags_parser=absl_flags.flags_parser)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File ".local/lib/python3.8/site-packages/paxml/main.py", line 382, in main
    run(experiment_config=experiment_config,
  File ".local/lib/python3.8/site-packages/paxml/main.py", line 336, in run
    search_space = tuning_lib.get_search_space(experiment_config)
  File "/home/robertli/.local/lib/python3.8/site-packages/paxml/tuning_lib.py", line 81, in get_search_space
    search_space = pg.hyper.trace(inspect_search_space, require_hyper_name=True)
  File "/home/robertli/.local/lib/python3.8/site-packages/pyglove/core/hyper/dynamic_evaluation.py", line 586, in trace
    fun()
  File "/home/robertli/.local/lib/python3.8/site-packages/paxml/tuning_lib.py", line 77, in inspect_search_space
    _ = instantiate(d)
  File "/home/robertli/.local/lib/python3.8/site-packages/praxis/base_hyperparams.py", line 1103, in instantiate
    return config.Instantiate(**kwargs)
  File "/home/robertli/.local/lib/python3.8/site-packages/praxis/base_hyperparams.py", line 601, in Instantiate
    return self.cls(self, **kwargs)
  File "/home/robertli/.local/lib/python3.8/site-packages/paxml/seqio_input.py", line 443, in __init__
    self._dataset = self._get_dataset()
  File "/home/robertli/.local/lib/python3.8/site-packages/paxml/seqio_input.py", line 551, in _get_dataset
    ds = self._get_backing_ds(
  File "/home/robertli/.local/lib/python3.8/site-packages/paxml/seqio_input.py", line 686, in _get_backing_ds
    ds = self.mixture_or_task.get_dataset(
  File "/home/robertli/.local/lib/python3.8/site-packages/seqio/dataset_providers.py", line 1205, in get_dataset
    len(self.source.list_shards(split=split)) >= shard_info.num_shards)
  File "/home/robertli/.local/lib/python3.8/site-packages/seqio/dataset_providers.py", line 455, in list_shards
    return [_get_filename(info) for info in self.tfds_dataset.files(split)]
  File "/home/robertli/.local/lib/python3.8/site-packages/seqio/utils.py", line 152, in files
    split_info = self.builder.info.splits[split]
  File "/home/robertli/.local/lib/python3.8/site-packages/seqio/utils.py", line 129, in builder
    LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = tfds.builder(
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow_datasets/core/logging/__init__.py", line 169, in __call__
    return function(*args, **kwargs)
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow_datasets/core/load.py", line 202, in builder
    return read_only_builder.builder_from_files(str(name), **builder_kwargs)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow_datasets/core/read_only_builder.py", line 259, in builder_from_files
    builder_dir = _find_builder_dir(name, **builder_kwargs)
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow_datasets/core/read_only_builder.py", line 327, in _find_builder_dir
    builder_dir = _find_builder_dir_single_dir(
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow_datasets/core/read_only_builder.py", line 417, in _find_builder_dir_single_dir
    found_version_str = _get_version_str(
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow_datasets/core/read_only_builder.py", line 484, in _get_version_str
    all_versions = version_lib.list_all_versions(os.fspath(builder_dir))
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow_datasets/core/utils/version.py", line 193, in list_all_versions
    if not root_dir.exists():
  File "/home/robertli/.local/lib/python3.8/site-packages/etils/epath/gpath.py", line 130, in exists
    return self._backend.exists(self._path_str)
  File "/home/robertli/.local/lib/python3.8/site-packages/etils/epath/backend.py", line 204, in exists
    return self.gfile.exists(path)
  File "/home/robertli/.local/lib/python3.8/site-packages/tensorflow/python/lib/io/file_io.py", line 288, in file_exists_v2
    _pywrap_file_io.FileExists(compat.path_to_bytes(path))
tensorflow.python.framework.errors_impl.PermissionDeniedError: Error executing an HTTP request: HTTP response code 403 with body '{
  "error": {
    "code": 403,
    "message": "[email protected] does not have storage.objects.get access to the Google Cloud Storage object. Permission 'storage.objects.get' denied on resource (or it may not exist).",
    "errors": [
      {
        "message": "[email protected] does not have storage.objects.get access to the Google Cloud Storage object. Permission 'storage.objects.get' denied on resource (or it may not exist)."'
	 when reading metadata of gs://mlperf-llm-public2/c4/en

I wonder if this is because I haven't configured something correctly, because the bucket seems like a public one.

I tried using the TFDS default bucket (gs://tfds-data/datasets) instead of gs://mlperf-llm-public2 and this problem doesn't arise, but it requires me to choose among available versions of c4 (not 3.0.4). Even then, I cannot proceed because it gives me some other error.

Thanks in advance for your attention and help!

ERROR: error loading package 'paxml'

I'm trying to run the PAX code in this repo:

I installed the prerequisites as mentioned in the repo:

python3 -m pip install -U pip
python3 -m pip install paxml praxis
cd paxml
bazel run -c opt --define=pax_task=lm \
    main -- \
    --exp=lm.decoder.ptb.PTBCharTransformerSmallSgd \
    --job_log_dir=/tmp/jax_log_dir/exp01 --alsologtostderr

To start training, I got the command line from here: https://github.com/google/paxml/blob/main/paxml/main.py#L19-L22

bazel run -c opt \
  third_party/py/paxml/tasks/lm/params:main -- \
  --exp=bert.BertAdamL4H128 \
  --job_log_dir=/tmp/jax_log_dir/exp01 --alsologtostderr

I'm encountering the following issue:

ERROR: Skipping 'paxml': error loading package 'paxml': Every .bzl file must have a corresponding package, but '//praxis:build-visibility.bzl' does not have one. Please create a BUILD file in the same or any parent directory. Note that this
BUILD file does not need to do anything except exist.
WARNING: Target pattern parsing failed.
ERROR: error loading package 'paxml': Every .bzl file must have a corresponding package, but '//praxis:build-visibility.bzl' does not have one. Please create a BUILD file in the same or any parent directory. Note that this BUILD file does not need to do anything except exist.
INFO: Elapsed time: 0.750s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (0 packages loaded)
FAILED: Build did NOT complete successfully (0 packages loaded)
    currently loading: paxml
    Fetching @com_google_protobuf; Restarting.

However, I can see that a BUILD file is present here: https://github.com/google/paxml/blob/main/paxml/BUILD

Is there anything I'm doing wrong? Any suggestions to get it running?

ARM64 Build

I've been trying to install PAXML on Ubuntu 22.04 ARM64, but I seem to be stuck getting lingvo (a mandatory dependency?) running there, and I have not been able to find a recipe for this. Has this been done? Is there any documentation about it?

Unexpected Overheads with Activation Checkpointing with Pipeline Parallelism

We notice buggy behavior with bitcasts and dynamic update slices. When we turn on activation checkpointing (e.g., saving the outputs of projection layers using the SAVE_OUT_PROJ flag in PAXML), we see multiple extra updates and copies.

For example, we want to checkpoint an activation of shape [2,2048,48,128]. However, in the HLO below we see that the copies are of shape [15,1,2,2048,48,128]. Here, 15 is the number of microbatches we are using with pipeline parallelism.

Snippet of HLO:

fusion.549 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, ..., kind=kLoop, calls=fused_computation.549, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
get-tuple-element.5874 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=0
copy.583 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5874)
get-tuple-element.5866 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=1
copy.575 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5866)
get-tuple-element.5868 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=2
copy.577 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5868)
get-tuple-element.5870 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=3
copy.579 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5870)
get-tuple-element.5872 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=4
copy.581 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5872)

...

fused_computation.549 {
  param_1.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(1)
  bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_1.8511)
  param_0.6313 = bf16[2,48,128,2048]{3,2,1,0} parameter(0)
  bitcast.52600 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.6313)
  param_2.5901 = s32[] parameter(2)
  constant_7564 = s32[] constant(0)
  compare.3477 = pred[] compare(param_2.5901, constant_7564), direction=LT, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/pipeline._scan_fn/pipeline._get_iteration_inputs/jit(remainder)/rem" source_file="/pax/praxis/praxis/layers/pipeline.py" source_line=422}
  constant_11524 = s32[] constant(15)
  add.6580 = s32[] add(param_2.5901, constant_11524), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/add" source_file="/pax/praxis/praxis/base_layer.py" source_line=695}
  select.5360 = s32[] select(compare.3477, add.6580, param_2.5901), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/select_n" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  dynamic-update-slice.325 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52601, bitcast.52600, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  bitcast.52599 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(dynamic-update-slice.325), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  param_4.7770 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(4)
  bitcast.52617.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_4.7770)
  param_3.8428 = bf16[2,48,128,2048]{3,2,1,0} parameter(3)
  bitcast.52616.clone.1 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_3.8428)
  dynamic-update-slice.333.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52617.clone.1, bitcast.52616.clone.1, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  ...
  ROOT tuple.356 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}) tuple(bitcast.52599, bitcast.52615.clone.1, bitcast.52611.clone.1, bitcast.52607.clone.1, bitcast.52603.clone.1)
}

It seems like there is a big buffer of size [15,1,2,2048,48,128] holding the activations for all microbatches. Within each microbatch, we are trying to update one row of this buffer (of shape [2,2048,48,128]). But XLA loads the entire buffer into memory, performs the update, and then copies the buffer back. We see this problem in our profiles. The amount of time spent on D2D copies (i.e., copy.575 to copy.583) is much larger than expected for the amount of data that should be copied. Right now, the time spent on activation checkpointing is 5% to 8% of the overall run time for a GPT-3 style model.
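
To put numbers on the gap described above, the sketch below compares the size of one microbatch slice with the size of the full buffer, using the shapes printed in the HLO and assuming 2 bytes per bf16 element. The helper `nbytes` is a hypothetical name introduced for illustration.

```python
import math

BYTES_PER_BF16 = 2  # bf16 elements are 16 bits wide


def nbytes(shape, itemsize=BYTES_PER_BF16):
    """Return the size in bytes of a dense array with the given shape."""
    return math.prod(shape) * itemsize


slice_shape = (2, 2048, 48, 128)          # one microbatch's activation
buffer_shape = (15, 1, 2, 2048, 48, 128)  # buffer holding all 15 microbatches

slice_mib = nbytes(slice_shape) / 2**20
buffer_mib = nbytes(buffer_shape) / 2**20
print(f"slice:  {slice_mib:.0f} MiB")
print(f"buffer: {buffer_mib:.0f} MiB ({buffer_mib / slice_mib:.0f}x the slice)")
```

If each of the copies in the snippet (copy.575 to copy.583) moves a whole buffer rather than a single slice, every loop iteration copies 15 times more data than the update itself requires, which would be consistent with the D2D copy time seen in the profiles.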

Our current understanding: the copies occur because the dataflow analysis treats bitcast as computing a new value (like a convert or sqrt), so a new tensor must be used in each loop iteration and a copy of each DUS result must be made. It should be possible to fix this by treating bitcast as an aliasing operation in the dataflow analysis. I think there is an option in the dataflow analysis that configures how bitcast is treated. In XLA TPU, that option is set to true, so bitcasts are treated as simply an aliasing operation.

Would someone be able to look into this?

I am attaching a link to the HLO: https://drive.google.com/drive/folders/1fYUsqfDgYRRpgOklE-k7qx_5ixkJzKPD?usp=sharing

JAX + TPU with AQT int8: training loss is abnormal

I used the aqt_einsum function in my code to quantize only the QK score, and then trained the model. However, I found that the loss drops very slowly after a certain number of training steps (around 200), which is quite different from the loss curve obtained with bfloat16. Am I missing something? For example, does the backward pass need some additional processing?
P.S. I am training the model with jax==0.4.23 on a TPU v5p-8.

In other words, is there a training example for AQT int8 in pax?
