
orbax's Introduction

Orbax

Orbax is a namespace providing common utility libraries for JAX users.

Checkpointing

pip install orbax-checkpoint (latest PyPI release) OR

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint' (from this repository, at HEAD)

import orbax.checkpoint

Orbax includes a checkpointing library oriented towards JAX users, supporting a variety of features required by different frameworks, including asynchronous checkpointing, multiple checkpoint types, and multiple storage formats. We aim to provide a highly customizable and composable API that maximizes flexibility for diverse use cases.

Note

Please see Announcements for important updates.

Exporting

pip install orbax-export (latest PyPI release) OR

pip install 'git+https://github.com/google/orbax/#subdirectory=export' (from this repository, at HEAD)

import orbax.export

Orbax also includes a serialization library for JAX users, enabling export of JAX models to the TensorFlow SavedModel format.

Note that orbax-export requires TensorFlow, but does not include it by default to allow for flexibility in version choice. If you wish to install with standard TensorFlow, please use pip install orbax-export[all].

Support

Contact [email protected] for help or with any questions about Orbax!

History

Orbax was initially published as a single catch-all package. To minimize dependency bloat for users, we have frozen that package at orbax-0.1.6 and will release future changes under the domain-specific packages detailed above (e.g. orbax-checkpoint).

As we have preserved the orbax namespace, existing import statements can remain unchanged (e.g. from orbax import checkpoint).


orbax's Issues

Issue loading checkpoint step during Paxml evaluation

I am attempting to run evaluation after training a model in Paxml. When trying to locate the step of the latest checkpoint, I see this error:

 File "/opt/paxml/paxml/eval_lib.py", line 482, in evaluate
    _common_eval_or_decode_loop(
 File "/opt/paxml/paxml/eval_lib.py", line 817, in _common_eval_or_decode_loop
    last_checkpoint_step = int(
ValueError: invalid literal for int() with base 10: 'PLACEHOLDER://step'

I have tracked down the issue in Paxml to this commit, but we believe this failure is a result of an underlying Orbax bug.

Command that fails in python 3.8

I have a command that runs successfully on TPU and fails on CPU. The CPU machine has more RAM, so I don't think that's the issue.
Both environments use the same orbax version (2.3.1).

If I don't set concurrent_gb=100 I get

ValueError: Requested more bytes than we reserved space for: 96636764160 > 96000000000

in both environments.

so I set it to 100 and then get

RuntimeError: Task <Task pending name='Task-23' coro=<async_deserialize.<locals>.cb() running at /home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py:261> cb=[gather.<locals>._done_callback() at /home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/tasks.py:769]> got Future <Future pending> attached to a different loop

on CPU but success on TPU.

Any idea?
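For context, the "attached to a different loop" error is a generic asyncio failure mode, not specific to orbax: it occurs whenever a Future created on one event loop is awaited from a task running on another. A minimal stdlib-only reproduction of the mechanism:

```python
import asyncio

# Create a Future bound to one event loop...
loop_a = asyncio.new_event_loop()
fut = loop_a.create_future()

async def wait_on_foreign_future():
    await fut  # ...and await it from a task running on a *different* loop

try:
    asyncio.run(wait_on_foreign_future())  # asyncio.run creates a fresh loop
except RuntimeError as e:
    error_message = str(e)
finally:
    loop_a.close()

assert "attached to a different loop" in error_message
```

In the traceback below, the foreign Future comes from the array-serialization byte limiter, which suggests an event loop is being created or reused inconsistently between it and the deserialization tasks on the CPU path.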

Code it's hitting

def create_orbax_checkpoint_manager(p):
    h = checkpoint.PyTreeCheckpointHandler(concurrent_gb=100)
    mstate_ckptr = Checkpointer(h)  # also fails with AsyncCheckpointer
    mngr = CheckpointManager(
        p,
        checkpointers={
            "model_state": mstate_ckptr,
            "dataloader_state": Checkpointer(DataloaderHandler()),
            "config": Checkpointer(TrainConfigHandler()),
            "train_rng": Checkpointer(DataloaderHandler()),
        },
        options=CheckpointManagerOptions(create=True, cleanup_tmp_directories=True),
    )
    return mngr


restore_args = jax.tree_util.tree_map(
    map_to_pspec, unboxed_train_state, state_mesh_annotations
)

ckpt_manager = create_orbax_checkpoint_manager(
    load_parameters_path,
)
logger.info(f"restoring state from {load_parameters_path=}")
items = {MS: {PARAMS: unboxed_train_state.params}}
kw = {MS: {"restore_args": {PARAMS: restore_args.params}}}
params = ckpt_manager.restore(step=1, items=items, restore_kwargs=kw)

Command

(On TPU i just remove the env vars)

TPU_ACCELERATOR_TYPE='' JAX_PLATFORM='cpu' XLA_FLAGS="--xla_force_host_platform_device_count=4" CUDA_VISIBLE_DEVICES="" JAX_CACHE_DIR=$HOME/.jax_cache python \
  eval.py -m exp.run_name=108b_heather_B16_dringus model=108b \
  exp.load_parameters_path=/mnt/resource_nvme/heather.0623.scale108b.jax.bf16.noep \
  exp.checkpoint_dir=/mnt/resource_nvme/heather.0623.scale108b.jax.sharded.noep "$@"

Traceback

Traceback (most recent call last):
  File "eval.py", line 279, in eval_main
    eval_loop(cfg)
  File "eval.py", line 166, in eval_loop
    (state, _, state_mesh_annotations, _,) = max_utils.setup_initial_state(
  File "/home/sam/character-tech/maxtext/MaxText/max_utils.py", line 424, in setup_initial_state
    ) = checkpointing.load_state_if_possible(
  File "/home/sam/character-tech/maxtext/MaxText/checkpointing.py", line 361, in load_state_if_possible
    params = ckpt_manager.restore(step=1, items=items, restore_kwargs=kw)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpoint_manager.py", line 565, in restore
    restored_items = self._restore_impl(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpoint_manager.py", line 597, in _restore_impl
    restored[item_name] = self._checkpointers[item_name].restore(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/checkpointer.py", line 97, in restore
    restored = self._handler.restore(directory, *args, item=item, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 569, in restore
    restored_item = asyncio.run(_restore())
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
    return future.result()
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 566, in _restore
    flat = await asyncio.gather(*flat)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/lazy_utils.py", line 48, in maybe_get_async
    return await value.get_async(*args, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/lazy_utils.py", line 30, in get_async
    return await self._get_fn(*args, **kwargs)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 448, in _deserialize
    return await handler.deserialize(info, args)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/type_handlers.py", line 541, in deserialize
    return await serialization.async_deserialize(
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 287, in async_deserialize
    return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 55, in create_async_array_from_callback
    dbs = await asyncio.gather(*future_arrays)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 261, in cb
    await byte_limiter.wait_for_bytes(requested_bytes)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py", line 130, in wait_for_bytes
    await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/locks.py", line 400, in wait_for
    await self.wait()
  File "/home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/locks.py", line 373, in wait
    await fut
RuntimeError: Task <Task pending name='Task-23' coro=<async_deserialize.<locals>.cb() running at /home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/jax/experimental/array_serialization/serialization.py:261> cb=[gather.<locals>._done_callback() at /home/sam/.conda/envs/char-latest/lib/python3.8/asyncio/tasks.py:769]> got Future <Future pending> attached to a different loop

Async checkpointer failure handling

We've observed the following behavior of orbax-checkpoint:

  • If async checkpoint saving fails, the main thread continues working.
    • In case that's relevant, checkpoint saving failed in our case because we were writing to a slow network disk and hit a barrier timeout in a multi-host training setup.
  • During saving of the next checkpoint, orbax tries to remove the previous checkpoint but fails because it can't find it (it was never successfully created).

Is there a way to handle this situation more gracefully? For instance, have a way to fail training immediately when async saving fails.
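One general pattern for the "fail immediately" behavior (a sketch of the idea, not orbax's actual API): keep a handle to the background save and re-raise its exception before starting the next one.

```python
import concurrent.futures

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
pending_save = None

def save_checkpoint(path):
    # Stand-in for the real async save; here it always fails,
    # like a barrier timeout on a slow network disk.
    raise OSError(f"timed out writing {path}")

def start_save(path):
    global pending_save
    # Surface any failure from the previous save before starting a new one.
    if pending_save is not None and pending_save.exception() is not None:
        raise RuntimeError("previous checkpoint save failed") from pending_save.exception()
    pending_save = executor.submit(save_checkpoint, path)

start_save("/ckpt/step_1")       # launches in the background
try:
    start_save("/ckpt/step_2")   # raises: step_1's save failed
except RuntimeError as e:
    failure = e
executor.shutdown(wait=False)
```

(Depending on your orbax-checkpoint version, AsyncCheckpointer may already expose methods such as wait_until_finished() that block on the previous save and surface its errors; calling that between steps achieves the same effect.)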

Drop unused dependencies `cached_property` and `importlib_resources`

I recently noticed that orbax-checkpoint depends on the cached-property package and importlib-resources backport, and wanted to suggest using the standard library instead (respectively @functools.cached_property in Python 3.8+ and the importlib.resources module in Python 3.7+; and note that Python 3.7 is end-of-life next week).

But then when I did a quick search to check how much work this would be... it looks like orbax-checkpoint literally doesn't use these packages at all. So dropping them should be quite trivial, and allow for somewhat smaller virtual environments 😁
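For reference, the stdlib replacement suggested above behaves the same as the third-party backport (the CheckpointInfo class here is invented purely for illustration):

```python
from functools import cached_property  # stdlib since Python 3.8

class CheckpointInfo:
    """Hypothetical class illustrating functools.cached_property."""

    def __init__(self, path):
        self.path = path
        self.computations = 0

    @cached_property
    def size_bytes(self):
        # Computed once on first access, then stored on the instance.
        self.computations += 1
        return len(self.path)  # stand-in for an expensive filesystem call

info = CheckpointInfo("/tmp/ckpt")
first = info.size_bytes
second = info.size_bytes
assert (first, second, info.computations) == (9, 9, 1)
```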

Dependency array-record Not Available on ARM64 Architecture

Issue

I am currently using Orbax for checkpointing in my project, which is being developed on an ARM64 architecture (Apple Silicon). Orbax has a transitive dependency on array-record through etils. However, array-record does not provide a version compatible with ARM64, leading to installation issues when using Poetry.

Environment

Orbax version: [0.1.9]
Python version: [3.9.13]
Operating System: macOS on ARM64 (Apple Silicon)
Dependency management: Poetry

Steps to Reproduce

  1. Include orbax as a dependency in a Python project managed with Poetry.
  2. Run poetry install on an ARM64 Mac.
  • Installing array-record (0.5.0): Failed

  RuntimeError

  Unable to find installation candidates for array-record (0.5.0)

  at ~/Library/Application Support/pypoetry/venv/lib/python3.8/site-packages/poetry/installation/chooser.py:73 in choose_for
       69│
       70│             links.append(link)
       71│
       72│         if not links:
    →  73│             raise RuntimeError(f"Unable to find installation candidates for {package}")
       74│
       75│         # Get the best link
       76│         chosen = max(links, key=lambda link: self._sort_key(package, link))
       77│

  Cannot install array-record.
  3. Observe that the installation fails due to array-record not being available for ARM64.

Connection between Orbax and array-record

> poetry show array-record --tree
array-record 0.5.0 A file format that achieves a new frontier of IO efficiency
├── absl-py *
└── etils *
    ├── absl-py *
    ├── fsspec *
    ├── importlib-resources *
    │   └── zipp >=3.1.0
    ├── numpy *
    ├── tqdm *
    │   └── colorama *
    ├── typing-extensions *
    └── zipp * (circular dependency aborted here)
> poetry show orbax --tree
orbax 0.1.9 Orbax
└── orbax-checkpoint >=0.1.8
    ├── absl-py *
    ├── etils *
    │   ├── absl-py * (circular dependency aborted here)
    │   ├── fsspec *
    │   ├── importlib-resources *
    │   │   └── zipp >=3.1.0
    │   ├── numpy *
    │   ├── tqdm *
    │   │   └── colorama *
    │   ├── typing-extensions *
    │   └── zipp * (circular dependency aborted here)
    ├── jax >=0.4.9
    │   ├── importlib-metadata >=4.6
    │   │   └── zipp >=0.5 (circular dependency aborted here)
    │   ├── ml-dtypes >=0.1.0
    │   │   ├── numpy >=1.21.2 (circular dependency aborted here)
    │   │   └── numpy >1.20 (circular dependency aborted here)
    │   ├── numpy >=1.21 (circular dependency aborted here)
    │   ├── opt-einsum *
    │   │   └── numpy >=1.7 (circular dependency aborted here)
    │   └── scipy >=1.7
    │       └── numpy >=1.21.6,<1.28.0 (circular dependency aborted here)
    ├── jaxlib *
    │   ├── ml-dtypes >=0.1.0 (circular dependency aborted here)
    │   ├── numpy >=1.21 (circular dependency aborted here)
    │   └── scipy >=1.7 (circular dependency aborted here)
    ├── msgpack *
    ├── nest-asyncio *
    ├── numpy * (circular dependency aborted here)
    ├── protobuf *
    ├── pyyaml *
    ├── tensorstore >=0.1.35
    │   └── numpy >=1.16.0 (circular dependency aborted here)
    └── typing-extensions * (circular dependency aborted here)

Extra

The issue lies in etils depending on array-record, and array-record doesn't publish a distribution for the ARM64 architecture, so installation fails:
https://pypi.org/project/array-record/#files

Poetry

Here is my pyproject.toml for reproducibility:

[tool.poetry.dependencies]
python = ">=3.9.0,<=3.9.18"
jaxtyping = "^0.2.11"

[tool.poetry.group.mltools]
optional = true

[tool.poetry.group.mltools.dependencies]
numpy = "^1.23.1"
scipy = "^1.9.0"
einops = "^0.5.0"
hydra-core = "^1.2.0"
omegaconf = "^2.2.3"
wandb = "^0.13.5"

[tool.poetry.group.dataset]
optional = true

[tool.poetry.group.dataset.dependencies]
tensorflow-macos = {version = "^2.12.0", platform = "darwin"}
tensorflow-datasets = "^4.7.0"

[tool.poetry.group.torch]
optional = true

[tool.poetry.group.torch.dependencies]
torch = "^1.13.1"
torchvision = "^0.14.1"
functorch = "^1.13.1"

[tool.poetry.group.jax]
optional = true

[tool.poetry.group.jax.dependencies]
jax-metal = { version = "^0.0.4", markers = "platform_machine == 'arm64'" }
flax = "^0.5.2"
optax = "^0.1.3"
orbax = "^0.1.9"

[tool.poetry.group.jupyter]
optional = true

[tool.poetry.group.jupyter.dependencies]
notebook = "^6.4.12"
jupyter = "^1.0.0"
ipykernel = "^6.15.1"
ipython = "^8.4.0"
requests = "^2.31.0"

[tool.poetry.group.additional]
optional = true

[tool.poetry.group.additional.dependencies]
black = {extras = ["jupyter"], version = "^22.6.0"}
pre-commit = "^2.20.0"
pytest = "^7.1.3"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Incorrect null check in pytree_checkpoint_handler.py

orbax/checkpoint/pytree_checkpoint_handler.py:661 has the following check: if not item

It most likely should be if item is None, as otherwise this check will raise an error when item is an array (an array is a valid pytree according to the pytree definition).
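The pitfall can be shown with two hypothetical guards (names invented for illustration): a truthiness check rejects valid falsy pytrees such as a zero-valued scalar leaf, while an identity check against None does not. (With a multi-element numpy array, not item additionally raises "truth value of an array ... is ambiguous", matching the reported error.)

```python
def guard_truthy(item):
    # The problematic form: any falsy pytree looks like "no item".
    if not item:
        raise ValueError("item not provided")
    return item

def guard_is_none(item):
    # The suggested form: only an actually-missing item is rejected.
    if item is None:
        raise ValueError("item not provided")
    return item

leaf = 0.0  # a valid single-leaf pytree, but falsy

assert guard_is_none(leaf) == 0.0  # accepted, as it should be

try:
    guard_truthy(leaf)
    truthy_rejected = False
except ValueError:
    truthy_rejected = True
assert truthy_rejected  # incorrectly rejected by the truthiness check
```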

New PyPI release

The current jax and flax versions do not work with the currently released version of orbax due to the deprecation of jax.experimental.gda_serialization, which was fixed here df7810f.

A new release to PyPI would be welcomed.

Thanks :)

Sharding checkpoints?

Currently looking into orbax thanks to the flax migration, but from what I can see in the documentation there isn't a good sharding option available.

However, I'd love to be proven wrong. Currently, attempting to save via Orbax results in a RAM overflow on one of my devices, so being able to move the model to CPU, shard it, and save the shards individually would be a major boon.

Maybe this is a potential feature to include? Or maybe I'm reading the documentation wrong, and you're already ahead of me?

Cannot install orbax-checkpoint due to uvicorn error

I was trying to rebuild my HuggingFace space when I encountered a really annoying error. I started to debug it and concluded that the problem is in orbax and its dependencies (or, less likely, in pip itself). The only line in my requirements.txt is:

orbax-checkpoint

Feel free to play with the demo:
https://huggingface.co/spaces/NickKolok/converter-dup

The full error logs are:

===== Build Queued at 2023-09-18 11:46:40 / Commit SHA: a7d6680 =====

--> FROM docker.io/library/python:3.10@sha256:1a8dcc07368065c2b285b24236a98e80355d6de7f685b416407aafb6dd32529f
DONE 0.0s

--> RUN useradd -m -u 1000 user
CACHED

--> RUN --mount=target=/root/packages.txt,source=packages.txt apt-get update && xargs -r -a /root/packages.txt apt-get install -y && rm -rf /var/lib/apt/lists/*
CACHED

--> RUN pip install --no-cache-dir pip==22.3.1 && pip install --no-cache-dir datasets "huggingface-hub>=0.12.1" "protobuf<4" "click<8.1" "pydantic~=1.0"
CACHED

--> WORKDIR /home/user/app
CACHED

--> RUN apt-get update && apt-get install -y git git-lfs ffmpeg libsm6 libxext6 cmake libgl1-mesa-glx && rm -rf /var/lib/apt/lists/* && git lfs install
CACHED

--> RUN --mount=target=pre-requirements.txt,source=pre-requirements.txt pip install --no-cache-dir -r pre-requirements.txt
CACHED

--> Restoring cache
DONE 8.2s

--> RUN --mount=target=requirements.txt,source=requirements.txt pip install --no-cache-dir -r requirements.txt
Defaulting to user installation because normal site-packages is not writeable
Collecting orbax-checkpoint
Downloading orbax_checkpoint-0.3.5-py3-none-any.whl (100 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.6/100.6 kB 7.2 MB/s eta 0:00:00
Collecting nest_asyncio
Downloading nest_asyncio-1.5.8-py3-none-any.whl (5.3 kB)
Requirement already satisfied: protobuf in /home/user/.local/lib/python3.10/site-packages (from orbax-checkpoint->-r requirements.txt (line 1)) (3.20.3)
Collecting etils[epath,epy]
Downloading etils-1.4.1-py3-none-any.whl (135 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 135.8/135.8 kB 57.7 MB/s eta 0:00:00
Collecting jaxlib
Downloading jaxlib-0.4.14-cp310-cp310-manylinux2014_x86_64.whl (73.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 73.7/73.7 MB 158.9 MB/s eta 0:00:00
Collecting tensorstore>=0.1.35
Downloading tensorstore-0.1.43-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.4/13.4 MB 232.4 MB/s eta 0:00:00
Requirement already satisfied: numpy in /home/user/.local/lib/python3.10/site-packages (from orbax-checkpoint->-r requirements.txt (line 1)) (1.25.2)
Collecting msgpack
Downloading msgpack-1.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (316 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 316.8/316.8 kB 461.4 MB/s eta 0:00:00
Collecting jax>=0.4.9
Downloading jax-0.4.14.tar.gz (1.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 442.8 MB/s eta 0:00:00
Installing build dependencies: started
Installing build dependencies: finished with status 'done'
Getting requirements to build wheel: started
Getting requirements to build wheel: finished with status 'done'
Preparing metadata (pyproject.toml): started
Preparing metadata (pyproject.toml): finished with status 'done'
Requirement already satisfied: typing_extensions in /home/user/.local/lib/python3.10/site-packages (from orbax-checkpoint->-r requirements.txt (line 1)) (4.7.1)
Collecting absl-py
Downloading absl_py-1.4.0-py3-none-any.whl (126 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 126.5/126.5 kB 403.4 MB/s eta 0:00:00
Requirement already satisfied: pyyaml in /home/user/.local/lib/python3.10/site-packages (from orbax-checkpoint->-r requirements.txt (line 1)) (6.0.1)
Collecting ml-dtypes>=0.2.0
Downloading ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 448.5 MB/s eta 0:00:00
Collecting opt-einsum
Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 65.5/65.5 kB 395.2 MB/s eta 0:00:00
Collecting scipy>=1.7
Downloading scipy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 36.3/36.3 MB 228.8 MB/s eta 0:00:00
Collecting importlib_resources
Downloading importlib_resources-6.0.1-py3-none-any.whl (34 kB)
Collecting zipp
Downloading zipp-3.16.2-py3-none-any.whl (7.2 kB)
Building wheels for collected packages: jax
Building wheel for jax (pyproject.toml): started
Building wheel for jax (pyproject.toml): finished with status 'done'
Created wheel for jax: filename=jax-0.4.14-py3-none-any.whl size=1535361 sha256=b82b3f01640aa7abd8dcfb90e521a5a04a3522eb9cbe55f4643472b8f8e9d216
Stored in directory: /tmp/pip-ephem-wheel-cache-5c326x1p/wheels/62/5e/1f/647158ef39dccf1b59baa0d9bf09fd25286398236bb5208194
Successfully built jax
Installing collected packages: msgpack, zipp, tensorstore, scipy, opt-einsum, nest_asyncio, ml-dtypes, importlib_resources, etils, absl-py, jaxlib, jax, orbax-checkpoint
Successfully installed absl-py-1.4.0 etils-1.4.1 importlib_resources-6.0.1 jax-0.4.14 jaxlib-0.4.14 ml-dtypes-0.2.0 msgpack-1.0.5 nest_asyncio-1.5.8 opt-einsum-3.3.0 orbax-checkpoint-0.3.5 scipy-1.11.2 tensorstore-0.1.43 zipp-3.16.2

[notice] A new release of pip available: 22.3.1 -> 23.2.1
[notice] To update, run: python -m pip install --upgrade pip
DONE 13.3s

--> RUN pip install --no-cache-dir gradio[oauth]==3.10.1 spaces==0.14.0
Defaulting to user installation because normal site-packages is not writeable
Collecting gradio[oauth]==3.10.1
Downloading gradio-3.10.1-py3-none-any.whl (11.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.6/11.6 MB 269.2 MB/s eta 0:00:00
Collecting spaces==0.14.0
Downloading spaces-0.14.0-py3-none-any.whl (9.4 kB)
WARNING: gradio 3.10.1 does not provide the extra 'oauth'
Requirement already satisfied: pandas in /home/user/.local/lib/python3.10/site-packages (from gradio[oauth]==3.10.1) (2.1.0)
Collecting websockets>=10.0
Downloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.9/129.9 kB 428.2 MB/s eta 0:00:00
Collecting markdown-it-py[linkify,plugins]
Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 87.5/87.5 kB 387.3 MB/s eta 0:00:00
Collecting httpx
Downloading httpx-0.25.0-py3-none-any.whl (75 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 75.7/75.7 kB 382.1 MB/s eta 0:00:00
Collecting orjson
Downloading orjson-3.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (138 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.7/138.7 kB 425.6 MB/s eta 0:00:00
Collecting jinja2
Downloading Jinja2-3.1.2-py3-none-any.whl (133 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 133.1/133.1 kB 385.9 MB/s eta 0:00:00
Collecting matplotlib
Downloading matplotlib-3.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.6/11.6 MB 289.6 MB/s eta 0:00:00
Requirement already satisfied: fsspec in /home/user/.local/lib/python3.10/site-packages (from gradio[oauth]==3.10.1) (2023.6.0)
Requirement already satisfied: pyyaml in /home/user/.local/lib/python3.10/site-packages (from gradio[oauth]==3.10.1) (6.0.1)
Collecting pydub
Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Collecting python-multipart
Downloading python_multipart-0.0.6-py3-none-any.whl (45 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.7/45.7 kB 257.0 MB/s eta 0:00:00
Collecting fastapi
Downloading fastapi-0.103.1-py3-none-any.whl (66 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.2/66.2 kB 344.1 MB/s eta 0:00:00
Collecting h11<0.13,>=0.11
Downloading h11-0.12.0-py3-none-any.whl (54 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.9/54.9 kB 328.5 MB/s eta 0:00:00
Collecting ffmpy
Downloading ffmpy-0.3.1.tar.gz (5.5 kB)
Preparing metadata (setup.py): started
Preparing metadata (setup.py): finished with status 'done'
Requirement already satisfied: numpy in /home/user/.local/lib/python3.10/site-packages (from gradio[oauth]==3.10.1) (1.25.2)
Collecting paramiko
Downloading paramiko-3.3.1-py3-none-any.whl (224 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 224.8/224.8 kB 435.7 MB/s eta 0:00:00
Collecting pillow
Downloading Pillow-10.0.1-cp310-cp310-manylinux_2_28_x86_64.whl (3.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 281.8 MB/s eta 0:00:00
Requirement already satisfied: aiohttp in /home/user/.local/lib/python3.10/site-packages (from gradio[oauth]==3.10.1) (3.8.5)
Requirement already satisfied: pydantic in /home/user/.local/lib/python3.10/site-packages (from gradio[oauth]==3.10.1) (1.10.12)
Requirement already satisfied: requests in /home/user/.local/lib/python3.10/site-packages (from gradio[oauth]==3.10.1) (2.31.0)
Collecting uvicorn
Downloading uvicorn-0.23.2-py3-none-any.whl (59 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.5/59.5 kB 335.2 MB/s eta 0:00:00
Collecting pycryptodome
Downloading pycryptodome-3.19.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 295.7 MB/s eta 0:00:00
Collecting psutil<6,>=2
Downloading psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (282 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 282.1/282.1 kB 441.1 MB/s eta 0:00:00
Requirement already satisfied: typing-extensions<5,>=4 in /home/user/.local/lib/python3.10/site-packages (from spaces==0.14.0) (4.7.1)
Requirement already satisfied: idna<4,>=2.5 in /home/user/.local/lib/python3.10/site-packages (from requests->gradio[oauth]==3.10.1) (3.4)
Requirement already satisfied: certifi>=2017.4.17 in /home/user/.local/lib/python3.10/site-packages (from requests->gradio[oauth]==3.10.1) (2023.7.22)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/user/.local/lib/python3.10/site-packages (from requests->gradio[oauth]==3.10.1) (3.2.0)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/user/.local/lib/python3.10/site-packages (from requests->gradio[oauth]==3.10.1) (2.0.4)
Requirement already satisfied: aiosignal>=1.1.2 in /home/user/.local/lib/python3.10/site-packages (from aiohttp->gradio[oauth]==3.10.1) (1.3.1)
Requirement already satisfied: frozenlist>=1.1.1 in /home/user/.local/lib/python3.10/site-packages (from aiohttp->gradio[oauth]==3.10.1) (1.4.0)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/user/.local/lib/python3.10/site-packages (from aiohttp->gradio[oauth]==3.10.1) (4.0.3)
Requirement already satisfied: multidict<7.0,>=4.5 in /home/user/.local/lib/python3.10/site-packages (from aiohttp->gradio[oauth]==3.10.1) (6.0.4)
Requirement already satisfied: yarl<2.0,>=1.0 in /home/user/.local/lib/python3.10/site-packages (from aiohttp->gradio[oauth]==3.10.1) (1.9.2)
Requirement already satisfied: attrs>=17.3.0 in /home/user/.local/lib/python3.10/site-packages (from aiohttp->gradio[oauth]==3.10.1) (23.1.0)
Collecting anyio<4.0.0,>=3.7.1
Downloading anyio-3.7.1-py3-none-any.whl (80 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 80.9/80.9 kB 392.3 MB/s eta 0:00:00
Collecting starlette<0.28.0,>=0.27.0
Downloading starlette-0.27.0-py3-none-any.whl (66 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.0/67.0 kB 373.5 MB/s eta 0:00:00
Collecting httpcore<0.19.0,>=0.18.0
Downloading httpcore-0.18.0-py3-none-any.whl (76 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.0/76.0 kB 373.8 MB/s eta 0:00:00
Collecting sniffio
Downloading sniffio-1.3.0-py3-none-any.whl (10 kB)
Collecting MarkupSafe>=2.0
Downloading MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (25 kB)
Collecting mdurl~=0.1
Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Collecting mdit-py-plugins
Downloading mdit_py_plugins-0.4.0-py3-none-any.whl (54 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.1/54.1 kB 349.8 MB/s eta 0:00:00
Collecting linkify-it-py<3,>=1
Downloading linkify_it_py-2.0.2-py3-none-any.whl (19 kB)
Collecting contourpy>=1.0.1
Downloading contourpy-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (301 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 301.7/301.7 kB 437.8 MB/s eta 0:00:00
Collecting cycler>=0.10
Downloading cycler-0.11.0-py3-none-any.whl (6.4 kB)
Requirement already satisfied: packaging>=20.0 in /home/user/.local/lib/python3.10/site-packages (from matplotlib->gradio[oauth]==3.10.1) (23.1)
Collecting pyparsing>=2.3.1
Downloading pyparsing-3.1.1-py3-none-any.whl (103 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 103.1/103.1 kB 374.6 MB/s eta 0:00:00
Collecting fonttools>=4.22.0
Downloading fonttools-4.42.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 267.9 MB/s eta 0:00:00
Collecting kiwisolver>=1.0.1
Downloading kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 447.2 MB/s eta 0:00:00
Requirement already satisfied: python-dateutil>=2.7 in /home/user/.local/lib/python3.10/site-packages (from matplotlib->gradio[oauth]==3.10.1) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /home/user/.local/lib/python3.10/site-packages (from pandas->gradio[oauth]==3.10.1) (2023.3.post1)
Requirement already satisfied: tzdata>=2022.1 in /home/user/.local/lib/python3.10/site-packages (from pandas->gradio[oauth]==3.10.1) (2023.3)
Collecting pynacl>=1.5
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 434.7 MB/s eta 0:00:00
Collecting cryptography>=3.3
Downloading cryptography-41.0.3-cp37-abi3-manylinux_2_28_x86_64.whl (4.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.3/4.3 MB 462.4 MB/s eta 0:00:00
Collecting bcrypt>=3.2
Downloading bcrypt-4.0.1-cp36-abi3-manylinux_2_28_x86_64.whl (593 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 593.7/593.7 kB 454.5 MB/s eta 0:00:00
Requirement already satisfied: click>=7.0 in /home/user/.local/lib/python3.10/site-packages (from uvicorn->gradio[oauth]==3.10.1) (8.0.4)
Collecting exceptiongroup
Downloading exceptiongroup-1.1.3-py3-none-any.whl (14 kB)
Collecting cffi>=1.12
Downloading cffi-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (441 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 441.8/441.8 kB 440.0 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of frozenlist to determine which version is compatible with other requirements. This could take a while.
Collecting frozenlist>=1.1.1
Downloading frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (225 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 225.7/225.7 kB 416.4 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of fonttools to determine which version is compatible with other requirements. This could take a while.
Collecting fonttools>=4.22.0
Downloading fonttools-4.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 240.1 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of cycler to determine which version is compatible with other requirements. This could take a while.
Collecting cycler>=0.10
Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)
INFO: pip is looking at multiple versions of cryptography to determine which version is compatible with other requirements. This could take a while.
Collecting cryptography>=3.3
Downloading cryptography-41.0.2-cp37-abi3-manylinux_2_28_x86_64.whl (4.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.3/4.3 MB 456.1 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of contourpy to determine which version is compatible with other requirements. This could take a while.
Collecting contourpy>=1.0.1
Downloading contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (300 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 300.7/300.7 kB 446.6 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of click to determine which version is compatible with other requirements. This could take a while.
Collecting click>=7.0
Downloading click-8.1.7-py3-none-any.whl (97 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 97.9/97.9 kB 378.4 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of charset-normalizer to determine which version is compatible with other requirements. This could take a while.
Collecting charset-normalizer<4,>=2
Downloading charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (201 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 201.8/201.8 kB 431.2 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of certifi to determine which version is compatible with other requirements. This could take a while.
Collecting certifi>=2017.4.17
Downloading certifi-2023.7.22-py3-none-any.whl (158 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 158.3/158.3 kB 401.4 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of bcrypt to determine which version is compatible with other requirements. This could take a while.
Collecting bcrypt>=3.2
Downloading bcrypt-4.0.0-cp36-abi3-manylinux_2_28_x86_64.whl (594 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 594.4/594.4 kB 436.5 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of attrs to determine which version is compatible with other requirements. This could take a while.
Collecting attrs>=17.3.0
Downloading attrs-23.1.0-py3-none-any.whl (61 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.2/61.2 kB 389.8 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of async-timeout to determine which version is compatible with other requirements. This could take a while.
Collecting async-timeout<5.0,>=4.0.0a3
Downloading async_timeout-4.0.3-py3-none-any.whl (5.7 kB)
INFO: pip is looking at multiple versions of anyio to determine which version is compatible with other requirements. This could take a while.
INFO: pip is looking at multiple versions of aiosignal to determine which version is compatible with other requirements. This could take a while.
Collecting aiosignal>=1.1.2
Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)
INFO: pip is looking at multiple versions of uvicorn to determine which version is compatible with other requirements. This could take a while.
Collecting uvicorn
Downloading uvicorn-0.23.1-py3-none-any.whl (59 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.5/59.5 kB 362.9 MB/s eta 0:00:00
Downloading uvicorn-0.23.0-py3-none-any.whl (59 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.6/59.6 kB 355.6 MB/s eta 0:00:00
Downloading uvicorn-0.22.0-py3-none-any.whl (58 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.3/58.3 kB 340.0 MB/s eta 0:00:00
Downloading uvicorn-0.21.1-py3-none-any.whl (57 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.8/57.8 kB 333.2 MB/s eta 0:00:00
Downloading uvicorn-0.21.0-py3-none-any.whl (57 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.8/57.8 kB 353.1 MB/s eta 0:00:00
Downloading uvicorn-0.20.0-py3-none-any.whl (56 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.9/56.9 kB 356.0 MB/s eta 0:00:00
Downloading uvicorn-0.19.0-py3-none-any.whl (56 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.6/56.6 kB 354.7 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of uvicorn to determine which version is compatible with other requirements. This could take a while.
Downloading uvicorn-0.18.3-py3-none-any.whl (57 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.4/57.4 kB 353.3 MB/s eta 0:00:00
Downloading uvicorn-0.18.2-py3-none-any.whl (57 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.0/57.0 kB 325.8 MB/s eta 0:00:00
Downloading uvicorn-0.18.1-py3-none-any.whl (57 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.0/57.0 kB 357.7 MB/s eta 0:00:00
Downloading uvicorn-0.18.0-py3-none-any.whl (57 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.0/57.0 kB 360.9 MB/s eta 0:00:00
Downloading uvicorn-0.17.6-py3-none-any.whl (53 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 53.6/53.6 kB 339.4 MB/s eta 0:00:00
Collecting asgiref>=3.4.0
Downloading asgiref-3.7.2-py3-none-any.whl (24 kB)
INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C.
Collecting uvicorn
Downloading uvicorn-0.17.5-py3-none-any.whl (53 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 53.6/53.6 kB 329.7 MB/s eta 0:00:00
Downloading uvicorn-0.17.4-py3-none-any.whl (52 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 52.6/52.6 kB 307.6 MB/s eta 0:00:00
Downloading uvicorn-0.17.3-py3-none-any.whl (52 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 52.6/52.6 kB 329.3 MB/s eta 0:00:00
Downloading uvicorn-0.17.2-py3-none-any.whl (52 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 52.7/52.7 kB 335.9 MB/s eta 0:00:00
Downloading uvicorn-0.17.1-py3-none-any.whl (54 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.7/54.7 kB 344.5 MB/s eta 0:00:00
Downloading uvicorn-0.17.0.post1-py3-none-any.whl (54 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.9/54.9 kB 341.1 MB/s eta 0:00:00
Downloading uvicorn-0.16.0-py3-none-any.whl (54 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.9/54.9 kB 314.2 MB/s eta 0:00:00
Downloading uvicorn-0.15.0-py3-none-any.whl (54 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.4/54.4 kB 333.0 MB/s eta 0:00:00
Downloading uvicorn-0.14.0-py3-none-any.whl (50 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.5/50.5 kB 338.7 MB/s eta 0:00:00
Downloading uvicorn-0.13.4-py3-none-any.whl (46 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 46.1/46.1 kB 321.5 MB/s eta 0:00:00
Collecting click==7.*
Downloading click-7.1.2-py2.py3-none-any.whl (82 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 82.8/82.8 kB 375.1 MB/s eta 0:00:00
Downloading click-7.1.1-py2.py3-none-any.whl (82 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 82.8/82.8 kB 355.3 MB/s eta 0:00:00
Downloading click-7.1-py2.py3-none-any.whl (82 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 82.8/82.8 kB 353.6 MB/s eta 0:00:00
Downloading Click-7.0-py2.py3-none-any.whl (81 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 81.3/81.3 kB 370.2 MB/s eta 0:00:00
Collecting uvicorn
Downloading uvicorn-0.13.3-py3-none-any.whl (45 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.7/45.7 kB 288.3 MB/s eta 0:00:00
INFO: pip is looking at multiple versions of click to determine which version is compatible with other requirements. This could take a while.
Downloading uvicorn-0.13.2-py3-none-any.whl (45 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.5/45.5 kB 300.5 MB/s eta 0:00:00
INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C.
Downloading uvicorn-0.13.1-py3-none-any.whl (45 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.5/45.5 kB 314.6 MB/s eta 0:00:00
Downloading uvicorn-0.13.0-py3-none-any.whl (45 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.7/45.7 kB 328.9 MB/s eta 0:00:00
Downloading uvicorn-0.12.3-py3-none-any.whl (45 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.7/45.7 kB 330.4 MB/s eta 0:00:00
Downloading uvicorn-0.12.2-py3-none-any.whl (45 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.1/45.1 kB 322.4 MB/s eta 0:00:00
Downloading uvicorn-0.12.1-py3-none-any.whl (44 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 kB 332.4 MB/s eta 0:00:00
Downloading uvicorn-0.12.0-py3-none-any.whl (44 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 kB 301.6 MB/s eta 0:00:00
Downloading uvicorn-0.11.8-py3-none-any.whl (43 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.3/43.3 kB 297.6 MB/s eta 0:00:00
Collecting httptools==0.1.*
Downloading httptools-0.1.2.tar.gz (106 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 106.7/106.7 kB 377.5 MB/s eta 0:00:00
Preparing metadata (setup.py): started
Preparing metadata (setup.py): finished with status 'done'
Collecting uvloop>=0.14.0
Downloading uvloop-0.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 307.3 MB/s eta 0:00:00
Collecting uvicorn
Downloading uvicorn-0.11.7-py3-none-any.whl (43 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.4/43.4 kB 253.8 MB/s eta 0:00:00
Downloading uvicorn-0.11.6-py3-none-any.whl (43 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.2/43.2 kB 325.0 MB/s eta 0:00:00
Downloading uvicorn-0.11.5-py3-none-any.whl (43 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.1/43.1 kB 327.3 MB/s eta 0:00:00
Downloading uvicorn-0.11.4-py3-none-any.whl (43 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.1/43.1 kB 321.3 MB/s eta 0:00:00
Downloading uvicorn-0.11.3-py3-none-any.whl (42 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.3/42.3 kB 310.8 MB/s eta 0:00:00
Downloading uvicorn-0.11.2-py3-none-any.whl (42 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.3/42.3 kB 321.6 MB/s eta 0:00:00
Downloading uvicorn-0.11.1-py3-none-any.whl (42 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 317.2 MB/s eta 0:00:00
Downloading uvicorn-0.11.0-py3-none-any.whl (42 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 312.4 MB/s eta 0:00:00
Downloading uvicorn-0.10.9-py3-none-any.whl (42 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.5/42.5 kB 312.5 MB/s eta 0:00:00
Downloading uvicorn-0.10.8.tar.gz (27 kB)
Preparing metadata (setup.py): started
Preparing metadata (setup.py): finished with status 'done'
[... pip backtracks through every uvicorn sdist from 0.10.7 down to 0.5.1; each download is followed by "Preparing metadata (setup.py): started / finished with status 'done'" ...]
Collecting httptools
Downloading httptools-0.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (428 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 428.8/428.8 kB 458.6 MB/s eta 0:00:00
Collecting uvicorn
Downloading uvicorn-0.5.0.tar.gz (23 kB)
Preparing metadata (setup.py): started
Preparing metadata (setup.py): finished with status 'done'
[... the backtracking continues through every uvicorn sdist from 0.4.6 down to 0.2.2, each "Preparing metadata (setup.py)" finishing with status 'done' ...]
Downloading uvicorn-0.2.1.tar.gz (10 kB)
Preparing metadata (setup.py): started
Preparing metadata (setup.py): finished with status 'error'
error: subprocess-exited-with-error

× python setup.py egg_info did not run successfully.
│ exit code: 1
╰─> [8 lines of output]
Traceback (most recent call last):
File "<string>", line 2, in <module>
File "<pip-setuptools-caller>", line 34, in <module>
File "/tmp/pip-install-xv61maro/uvicorn_45ae9303eefb4330b54c173fd50eba19/setup.py", line 41, in <module>
long_description=get_long_description(),
File "/tmp/pip-install-xv61maro/uvicorn_45ae9303eefb4330b54c173fd50eba19/setup.py", line 23, in get_long_description
return open('README.md', 'r').read()
FileNotFoundError: [Errno 2] No such file or directory: 'README.md'
[end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

[notice] A new release of pip available: 22.3.1 -> 23.2.1
[notice] To update, run: python -m pip install --upgrade pip

--> ERROR: process "/bin/sh -c pip install --no-cache-dir ${SDK}==${SDK_VERSION} spaces==${PYSPACES_VERSION}" did not complete successfully: exit code: 1

</details>
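The actual failure at the bottom of the log is in uvicorn 0.2.1's setup.py, which unconditionally does `open('README.md', 'r').read()` to build `long_description`; the sdist pip downloaded does not ship that file, so metadata generation crashes and the build aborts. A minimal sketch of the failure mode, plus a hypothetical guarded variant (my suggestion, not uvicorn's actual code):

```python
import os
import tempfile

def get_long_description(path="README.md"):
    # Mirrors the failing call in uvicorn 0.2.1's setup.py:
    # it reads README.md with no existence check.
    return open(path, "r", encoding="utf-8").read()

def get_long_description_safe(path="README.md"):
    # Hypothetical guarded variant: fall back to an empty string
    # instead of crashing metadata generation.
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    return ""

if __name__ == "__main__":
    os.chdir(tempfile.mkdtemp())  # simulate an sdist tree without README.md
    try:
        get_long_description()
    except FileNotFoundError:
        print("egg_info would fail here")
    print(repr(get_long_description_safe()))  # → ''
```

In practice, constraining the resolver to versions new enough to ship working wheels avoids backtracking into these old sdists at all.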

Best practice for tracking the best *and* latest?

Hello!

From my understanding, when you provide a best_fn to CheckpointManagerOptions, checkpoints are no longer retained based on recency.

What if I want to keep checkpoints based on both "goodness" and recency? One way to do that seems to be to run two instances of CheckpointManager, but that feels inelegant... surely it should be possible with just one, right?
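To make the question concrete: the policy I have in mind keeps the union of the k most recent steps and the k best-scoring steps. A pure-Python sketch of just that retention rule (hypothetical helper, not an Orbax API):

```python
def steps_to_keep(metrics, k_latest=2, k_best=2, mode="min"):
    """Return the set of steps a combined best+latest policy would retain.

    metrics: dict mapping step -> scalar metric (e.g. eval loss).
    Hypothetical helper; this just spells out the retention rule.
    """
    steps = sorted(metrics)
    latest = set(steps[-k_latest:])  # k most recent steps
    ranked = sorted(steps, key=lambda s: metrics[s], reverse=(mode == "max"))
    best = set(ranked[:k_best])      # k best-scoring steps
    return latest | best

# Example: lower loss is better.
losses = {100: 0.9, 200: 0.4, 300: 0.7, 400: 0.6}
print(sorted(steps_to_keep(losses)))  # → [200, 300, 400]
```

Here step 100 is dropped because it is neither recent nor among the best; everything a single manager would need is this union instead of one criterion or the other.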

ValueError: NOT_FOUND when trying to save train state in a Docker container

I'm getting the following error when I try to save my train state from within a docker container:

Traceback (most recent call last):
  File "/project/test.py", line 32, in <module>
    checkpointer.save(os.path.abspath('checkpoints/checkpoint1'), state)
  File "/usr/local/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py", line 81, in save
    self._handler.save(tmpdir, item, *args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 756, in save
    asyncio.run(async_save(directory, item, *args, **kwargs))
  File "/usr/local/lib/python3.9/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/lib/python3.9/asyncio/base_events.py", line 647, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.9/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 754, in async_save
    future.result()  # Block on result.
ValueError: NOT_FOUND: Error opening "cast" driver: Error opening "zarr" driver: Error writing "step/.zarray" in OCDBT database at local file "/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/": Error writing local file "/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/manifest.ocdbt": Error getting file info: /project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/manifest.ocdbt.__lock [OS error: No such file or directory] [tensorstore_spec[1]='{\"base\":{\"create\":true,\"driver\":\"zarr\",\"dtype\":\"int64\",\"kvstore\":{\"base\":{\"driver\":\"file\",\"path\":\"/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/\"},\"cache_pool\":\"cache_pool#ocdbt\",\"config\":{\"max_decoded_node_bytes\":100000000,\"max_inline_value_bytes\":1024},\"driver\":\"ocdbt\",\"experimental_read_coalescing_threshold_bytes\":1000000,\"path\":\"step/\"},\"metadata\":{\"chunks\":[],\"compressor\":{\"id\":\"zstd\",\"level\":1},\"shape\":[]},\"open\":true,\"recheck_cached_data\":false,\"recheck_cached_metadata\":false},\"context\":{\"cache_pool\":{},\"cache_pool#ocdbt\":{\"total_bytes_limit\":100000000},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"ocdbt_coordinator\":{}},\"driver\":\"cast\",\"dtype\":\"int64\"}'] [source locations='tensorstore/kvstore/kvstore.cc:268\ntensorstore/kvstore/kvstore.cc:268\ntensorstore/driver/driver.cc:114\ntensorstore/driver/driver.cc:114'] 
[tensorstore_spec='{\"context\":{\"cache_pool\":{},\"cache_pool#ocdbt\":{\"total_bytes_limit\":100000000},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"ocdbt_coordinator\":{}},\"create\":true,\"driver\":\"zarr\",\"dtype\":\"int64\",\"kvstore\":{\"base\":{\"driver\":\"file\",\"path\":\"/project/checkpoints/checkpoint1.orbax-checkpoint-tmp-1690941767098696/\"},\"cache_pool\":\"cache_pool#ocdbt\",\"config\":{\"max_decoded_node_bytes\":100000000,\"max_inline_value_bytes\":1024},\"driver\":\"ocdbt\",\"experimental_read_coalescing_threshold_bytes\":1000000,\"path\":\"step/\"},\"metadata\":{\"chunks\":[],\"compressor\":{\"id\":\"zstd\",\"level\":1},\"shape\":[]},\"open\":true,\"recheck_cached_data\":false,\"recheck_cached_metadata\":false}']
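For context on the odd path in the traceback: Orbax first writes into a temporary `*.orbax-checkpoint-tmp-<timestamp>` directory and only finalizes it on success, which is why the lock file lives under that path. A generic write-to-temp-then-rename sketch (not Orbax's actual implementation):

```python
import os
import tempfile

def atomic_save(root, name, payload: bytes):
    # Write into a sibling temp directory, then rename it into place,
    # so a crash mid-write never leaves a half-finished checkpoint
    # under the final name.
    final = os.path.join(root, name)
    tmp = tempfile.mkdtemp(prefix=name + ".tmp-", dir=root)
    with open(os.path.join(tmp, "payload"), "wb") as f:
        f.write(payload)
    os.replace(tmp, final)  # atomic on POSIX within one filesystem
    return final

root = tempfile.mkdtemp()
path = atomic_save(root, "checkpoint1", b"state-bytes")
print(os.path.basename(path))  # → checkpoint1
```

The error above fires while the checkpoint is still in that temporary directory, before any rename happens.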

Here's the code to reproduce:

import flax.linen as nn
from flax.training import train_state
import optax
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp
import os

def create_train_state(module, rng):
    x = (jnp.ones([1, 256, 256, 1]))
    variables = module.init(rng, x)
    params = variables['params']
    tx = optax.adam(1e-3)
    ts = train_state.TrainState.create(
        apply_fn=module.apply, params=params, tx=tx
    )
    return ts
    
class TestModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(4, kernel_size=(3, 3))(x)
        return x
    
if __name__ == '__main__':
    init_rng = jax.random.PRNGKey(0)
    model = TestModel()
    state = create_train_state(model, init_rng)
    del init_rng

    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=True))
    checkpointer.save(os.path.abspath('checkpoints/checkpoint1'), state)

And here's my docker setup:

Dockerfile:

FROM python:3.9.17-slim-bullseye

WORKDIR /project
COPY requirements.txt requirements.txt

RUN python -m pip install --upgrade pip
RUN python -m pip install jupyterlab flax orbax-checkpoint jax

EXPOSE 8888

ENTRYPOINT ["jupyter", "lab", "--ip=0.0.0.0", "--allow-root", "--no-browser", "--NotebookApp.token=''", "--NotebookApp.password=''"]

docker-compose.yaml:

services:
  test:
    build: .
    ports:
      - 8888:8888
    volumes:
      - .:/project
    deploy:
      resources:
        reservations:
          devices:
          - capabilities: [gpu]

I use the following commands to build and enter my docker container:

docker-compose build
docker-compose up -d
docker-compose exec test bash

Then I create the checkpoint directory:

mkdir checkpoints

From here you can run the reproduction code.

I've been able to reproduce this error in a couple of different docker environments, but this one is the simplest. For some reason it does not reproduce in Colab.

save_checkpoint_multiprocess gets stuck on pod slices

Hello,
I'm training a model on a TPU v4-16, but when the training script calls save_checkpoint_multiprocess, the program hangs indefinitely.

Some other notes:

  • The same code works on a TPU v4-8 using save_checkpoint.
  • I also tried save_checkpoint on the TPU v4-16, with the same outcome: the program hangs indefinitely.
  • If I remove save_checkpoint_multiprocess from my code, the script works as expected.

Here is a code snippet:

    jax.distributed.initialize()

    def save_model(state, step=0):
        ckpt = {'model': state}
        # orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())
        async_checkpointer = orbax.AsyncCheckpointer(orbax.PyTreeCheckpointHandler(), timeout_secs=30)
        print("Created async checkpointer")
        # Stuck on the line below
        checkpoints.save_checkpoint_multiprocess(ckpt_dir=checkpoint_path, target=ckpt, prefix='checkpoint_', step=step, overwrite=True, keep=1, orbax_checkpointer=async_checkpointer)
        print("checkpoint saved")

# some code later
    if iter_num % eval_interval == 0 and jax.process_index() == 0:
        print(f"{jax.process_index()} is evaluating")
        eval_metrics = evaluate(rng, p_eval_step, state.params['params'])
        train_loss = jax.device_get(eval_metrics['train'])
        val_loss = jax.device_get(eval_metrics['val'])
        writer.add_scalar('train/loss', train_loss, global_step=iter_num)
        writer.add_scalar('val/loss', val_loss, global_step=iter_num)
        lr = jax.device_get(unreplicate(learning_rate_fn(state.step)))
        writer.add_scalar('lr', lr, global_step=iter_num)
        if val_loss < best_eval:
            best_eval = val_loss
            if jax.process_index() == 0:
                print("saving model")
                save_model(unreplicate(state), step=iter_num)

The model code can be found here

Full training script:

from functools import partial
import argparse
import os
import math
import time
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import jax
import jax.numpy as jnp
import flax
from flax.core.frozen_dict import unfreeze
from flax.training import train_state, checkpoints
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard, shard_prng_key

import optax
import orbax.checkpoint as orbax

from model import GPTConfig, GPT
from configs.shakespeare import config as shakespeare_config
from configs.openwebtext10k import config as openwebtext10k_config

# device_count = jax.device_count()
# local_device_count = jax.local_device_count()

jax.distributed.initialize()

import tiktoken
enc = tiktoken.get_encoding("gpt2")


model_config_args = {
    'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
    'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
    'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
    'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}

def get_train_config(train_config_name):
    train_config = None
    if train_config_name == 'shakespeare':
        train_config = shakespeare_config
    elif train_config_name == 'openwebtext-10k':
        train_config = openwebtext10k_config
    return train_config

def count_params(params) -> int:
    p = jax.tree_util.tree_map(lambda a: a.size if isinstance(a, jnp.ndarray) else 0, params)
    return jax.tree_util.tree_reduce(lambda a, b: a + b, p)

def train(opt):
    seed = 250
    config = GPTConfig()

    train_config = get_train_config(opt.config)
    model_config = model_config_args[train_config['init_from']]
    config.n_layer = model_config['n_layer']
    config.n_head = model_config['n_head']
    config.n_embd = model_config['n_embd']
    config.block_size = train_config['block_size']

    # Eval
    out_dir = train_config['out_dir']
    log_dir = os.path.join(out_dir,'logs')
    writer = SummaryWriter(log_dir=log_dir)
    eval_interval = train_config['eval_interval']
    eval_iters = train_config['eval_iters']
    best_eval = 1e6

    # Checkpoints
    checkpoint_path = os.path.join(out_dir,'checkpoints')
    def restore_model(state, step=0):
        state_restored = checkpoints.restore_checkpoint(ckpt_dir=checkpoint_path, target=state, step=step)
        return state_restored

    def save_model(state, step=0):
        ckpt = {'model' : state}
        #orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())
        async_checkpointer = orbax.AsyncCheckpointer(orbax.PyTreeCheckpointHandler(), timeout_secs=30)
        print("Created async checkpointer")

        checkpoints.save_checkpoint_multiprocess(ckpt_dir=checkpoint_path, target=ckpt, prefix='checkpoint_', step=step, overwrite=True, keep=1, orbax_checkpointer=async_checkpointer)
        print("checkpoint saved")

    # Generate
    gen_interval = train_config['gen_interval']
    def log_generations(text, step):
        writer.add_text("generation",text,global_step=step)

    # Data
    batch_size = train_config['batch_size']
    dataset = train_config['dataset']
    data_dir = os.path.join('data', dataset)
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    def get_batch(split, rng, batch_size):
        data = train_data if split == 'train' else val_data
        ix = jax.random.randint(rng, (batch_size,), minval=0, maxval=(len(data) - config.block_size))
        x = jnp.stack([data[i:i+config.block_size] for i in ix],dtype=jnp.int32)
        y = jnp.stack([data[i+1:i+1+config.block_size] for i in ix],dtype=jnp.int32)
        return x, y

    iter_num = 0

    # Training
    rng = jax.random.PRNGKey(seed)
    rng, init_rng = jax.random.split(rng)
    train_batch_size = batch_size * jax.device_count()
    input_shape = (batch_size, config.block_size)
    model = GPT(config)
    main_rng, init_rng, dropout_init_rng = jax.random.split(rng, 3)
    params = jax.jit(model.init)({'params' : init_rng, 'dropout' : dropout_init_rng}, jax.random.randint(init_rng, input_shape, minval=0,maxval=config.vocab_size))
    print("Number of params : ",count_params(params))
    # Optimizer
    # learning rate decay settings
    learning_rate= train_config['learning_rate']
    max_iters = train_config['max_iters']
    lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
    warmup_iters = max_iters // 300 # how many steps to warm up for
    min_lr = learning_rate / 10 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla

    def create_learning_rate_schedule():
        return optax.warmup_cosine_decay_schedule(
                init_value=0,
                peak_value=learning_rate,
                warmup_steps=2000,
                decay_steps=lr_decay_iters,
                end_value=min_lr
            )
    def create_adamw_mask(params, prev_key=None):
        retval = {}
        for key in params.keys():
            val = params[key]
            # print("prev_key:",prev_key," | key:",key," | val type: ",type(val)," | val:",val)
            if isinstance(val, flax.core.frozen_dict.FrozenDict):
                retval[key] = create_adamw_mask(val,key)
            else:
                if "ln_" in key or "bias" in key or "embedding" in key:
                    retval[key] = False
                else:
                    retval[key] = True
        # print(retval)
        return retval

    learning_rate_fn = create_learning_rate_schedule()
    decay_mask = create_adamw_mask(params)
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.add_decayed_weights(1e-2, mask=decay_mask),
        optax.adamw(
            learning_rate=learning_rate_fn,
            b1=0.9, b2=0.95, eps=1e-8
        )
    )
    state = train_state.TrainState.create(
        apply_fn=model.apply, 
        params=unfreeze(params), 
        tx=optimizer)

    del params

    state = jax_utils.replicate(state)


    def temperature_sample(idx, params, config, max_new_tokens,temperature=1.0, top_k=20, rng=jax.random.PRNGKey(0)):
        sampling_loop_init_state = (jnp.array(0), idx, params, rng)
        def select_top_k(tensor, k):
            values, _ = jax.lax.top_k(tensor, k)
            mask = tensor > values.min()
            return mask, jnp.where(mask, tensor, 0.)
        def log(t, eps = 1e-20):
            return jnp.log(t + eps)
        def gumbel_noise(rng, shape):
            noise = jax.random.uniform(rng, shape = shape, minval = 0., maxval = 1.)
            return -log(-log(noise))
        def sampling_loop_cond_fn(state):
            (i, _, _, _) = state
            return i <= max_new_tokens

        def sampling_loop_body_fn(state):
            i, idx, params, rng = state
            rng0, rng1 = jax.random.split(rng)
            model = GPT(config)
            logits = model.apply({'params' : params}, idx, train=False)
            logits = logits[:,-1,:]
            noise = gumbel_noise(rng0, logits.shape)
            if top_k:
                mask, logits = select_top_k(logits, top_k)
                noise *= mask
            
            logits += noise
            sampled_ind = jnp.argmax(logits, axis=-1)  # jnp, not np: this runs inside a traced while_loop
            idx = jnp.concatenate((idx, jnp.array([sampled_ind])), axis=1)
            idx_cond = idx if idx.shape[1] <= config.block_size else idx[:,-config.block_size:]
            return (i+1, idx_cond, params, rng1)

        final_state = jax.lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn, sampling_loop_init_state)
        return final_state[1]

    def eval_step(params, batch):
        inputs, targets = batch
        logits = model.apply({'params' : params}, inputs, train=False)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
        return loss

    def evaluate(rng, p_eval_step, params):
        out = {}
        for split in ['train', 'val']:
            losses = jnp.zeros(eval_iters)
            for k in range(eval_iters):
                rng, input_rng = jax.random.split(rng)
                batch = get_batch(split,input_rng, train_batch_size)
                batch = shard(batch)
                loss = p_eval_step(params, batch)
                losses = losses.at[k].set(loss[0])
            out[split] = losses.mean()
        return out


    def train_step(state, batch, dropout_rng=None):
        inputs, targets = batch
        dropout_rng = jax.random.fold_in(dropout_rng, state.step)

        def compute_loss(params):
            logits = state.apply_fn({'params': params['params']}, inputs, train=True, rngs={'dropout' : dropout_rng})
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
            return loss
        
        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = {"loss" : loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics


    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    print("starting training...")
    for _ in range(max_iters):
        rng, input_rng = jax.random.split(rng)

        # Generate single example
        if iter_num % gen_interval == 0 and jax.process_index() == 0:
            print(f" {jax.process_index()} is generating")
            x,_ = get_batch('train',input_rng, 1)
            generation = temperature_sample(x, unreplicate(state).params['params'], config, max_new_tokens=50,temperature=1.0, top_k=20, rng=rng)
            generation = generation.squeeze()
            generation = enc.decode(generation)
            log_generations(generation, iter_num)
            writer.flush()

        # eval
        if iter_num % eval_interval == 0 and jax.process_index() == 0:
            print(f" {jax.process_index()} is evaluating")
            eval_metrics = evaluate(rng, p_eval_step, state.params['params'])
            train_loss = jax.device_get(eval_metrics['train'])
            val_loss = jax.device_get(eval_metrics['val'])
            writer.add_scalar('train/loss', train_loss,global_step=iter_num)
            writer.add_scalar('val/loss', val_loss,global_step=iter_num)
            lr = jax.device_get(unreplicate(learning_rate_fn(state.step)))
            writer.add_scalar('lr', lr,global_step=iter_num)
            if val_loss < best_eval:
                best_eval = val_loss
                if jax.process_index() == 0:
                    print("saving model")
                    save_model(unreplicate(state), step=iter_num)
                # Restore checkpoint to validate we can load checkpoints.
                # state = restore_model(unreplicate(state), step=iter_num)
                # state = jax_utils.replicate(state)

            writer.flush()
        # Train
        batch = get_batch('train',input_rng, train_batch_size)
        batch = shard(batch)
        state, train_metric = p_train_step(state, batch, dropout_rngs)
        train_metric = unreplicate(train_metric)
        if jax.process_index() == 0:
            print(f" iter {iter_num}, loss : {train_metric['loss']}, lr: {learning_rate_fn(state.step)}")
        iter_num+=1

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Config. Ex: shakespeare, openwebtext-10k or openwebtext"
    )
    opt = parser.parse_args()
    train(opt)

Thank you!

Dev dependencies should be made optional

It looks like Orbax's dependencies currently include dev dependencies, e.g. pytest on L33:

orbax/pyproject.toml

Lines 23 to 36 in 55454ac

dependencies = [
'absl-py',
'cached_property',
'importlib_resources',
'etils',
'flax',
'importlib_resources',
'jax',
'jaxlib',
'numpy',
'pytest',
'pyyaml',
'tensorstore >= 0.1.20',
]

These should be turned into optional dependencies, as shown in PEP 631.

Due to Flax's dependency on Orbax, all Flax users currently install these dev dependencies.
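For illustration, the split could look roughly like this in pyproject.toml. The group name "testing" is a suggestion, not the project's actual layout; the mechanism is the standard [project.optional-dependencies] table from PEP 621/631:

```toml
[project]
dependencies = [
  'absl-py',
  'etils',
  'jax',
  # ... runtime-only dependencies ...
]

[project.optional-dependencies]
testing = [
  'pytest',
]
```

Users would then install the dev extras explicitly with pip install 'orbax[testing]', while plain pip install orbax (and Flax's transitive install) would skip pytest.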

Unable to `pip install orbax-checkpoint`

$ pip install orbax-checkpoint
ERROR: Could not find a version that satisfies the requirement orbax-checkpoint (from versions: none)
ERROR: No matching distribution found for orbax-checkpoint
WARNING: You are using pip version 20.0.2; however, version 23.1.1 is available.
You should consider upgrading via the '/Tmp/slurm.3120526.0/jaxrl/bin/python -m pip install --upgrade pip' command.

$ pip install --upgrade pip
Collecting pip
  Downloading pip-23.1.1-py3-none-any.whl (2.1 MB)
     |β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2.1 MB 32.5 MB/s 
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 20.0.2
    Uninstalling pip-20.0.2:
      Successfully uninstalled pip-20.0.2
Successfully installed pip-23.1.1

$ pip install orbax-checkpoint
ERROR: Ignored the following versions that require a different python version: 0.0.0 Requires-Python >=3.8; 0.1.1 Requires-Python >=3.8; 0.1.4 Requires-Python >=3.8; 0.1.6 Requires-Python >=3.8; 0.1.7 Requires-Python >=3.8; 0.1.8 Requires-Python >=3.8; 0.2.0 Requires-Python >=3.8; 0.2.1 Requires-Python >=3.8
ERROR: Could not find a version that satisfies the requirement orbax-checkpoint (from versions: none)
ERROR: No matching distribution found for orbax-checkpoint

orbax is unable to save dm-haiku pytrees.

Hi,

I am experimenting with using orbax for saving checkpoints consisting of dm-haiku state and parameters. However, it seems that the pytree checkpointer is unable to handle the "~" characters used inside haiku params/state keys.

A simple reproducing example:

from absl import app
import orbax.checkpoint
import jax.numpy as jnp
import pathlib


def main(_):
    ckpt_dir = pathlib.Path("logs")
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    state = {"a/~/foo": jnp.ones((2, 3))} # Changing to "a" works fine.
    ckpt_manager = orbax.checkpoint.CheckpointManager(
        ckpt_dir,
        orbax.checkpoint.PyTreeCheckpointer(),
    )
    print(ckpt_manager.latest_step())
    ckpt_manager.save(0, state)

if __name__ == "__main__":
    app.run(main)

How to restore a checkpoint with only partial knowledge of it?

Hi,

After several training runs, I have existing checkpoints that have dict(params=..., opt_state=...). However, for debugging purposes, I also want to store the last executed data batch in new checkpoints, like this:

dict(params=..., opt_state=..., last_batch=...)

Is there a way to restore both types of these checkpoints?

I had been using the idiom:

init_state = dict(params=params, opt_state=opt_state)
shardings = jax.tree_map(lambda x: x.sharding, init_state)
restore_args = utils.construct_restore_args(init_state, shardings)
restored = manager.restore(hps.resume_ckpt, items=init_state, restore_kwargs=
        {'restore_args': restore_args})

Although it would have been convenient to use:

restored = manager.restore(hps.resume_ckpt, items=None, restore_kwargs=None)

unfortunately, the opt_state requires special handling, so the items=None form for restore doesn't work for that.

I wondered if there's any way to get the best of both worlds - the flexibility of items=None form to restore checkpoints of different shape (with or without the last_batch) but also provide special restore_kwargs for just one item in the checkpoint (the opt_state).

Thanks!
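One pattern that may help, sketched here as plain tree manipulation: build the restore template dynamically from whichever keys the checkpoint actually contains, using a concrete item (with restore_args) for keys that need special handling and falling back to free-form restore (items=None semantics) for the rest. `saved_keys` and `known_templates` below are hypothetical stand-ins for what you would obtain from the checkpoint's metadata and from your init code:

```python
def build_template(saved_keys, known_templates):
    """Use a concrete template for keys we know how to restore (e.g. opt_state),
    and None (free-form restore) for keys we have no template for."""
    return {k: known_templates.get(k) for k in saved_keys}

known = {"params": {"w": "abstract-array"}, "opt_state": {"mu": "abstract-array"}}

# Old checkpoint layout, without last_batch:
old = build_template(["params", "opt_state"], known)
assert old == {"params": {"w": "abstract-array"},
               "opt_state": {"mu": "abstract-array"}}

# New checkpoint layout: last_batch has no template, so it would be
# restored free-form while opt_state still gets its special handling.
new = build_template(["params", "opt_state", "last_batch"], known)
assert new["last_batch"] is None
```

The same template tree could then drive construction of per-key restore_args, so only the known items carry special kwargs.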

jax.Array must be fully replicated to be saved in aggregate file

I'm trying to save a checkpoint and am getting this error message. Saving code:

ckpt = {'state': state, 'config': model.config} 
save_args = orbax_utils.save_args_from_target(ckpt)
checkpoint_manager.save(global_step + 1, ckpt, save_kwargs={'save_args': save_args})

This line in orbax/checkpoint/pytree_checkpoint_handler.py is throwing the error:

if isinstance(value, jax.Array) and not value.is_fully_replicated:
     raise ValueError(
         'jax.Array must be fully replicated to be saved in aggregate file.'
     )

state is an instance of flax.training.train_state. What could be causing this? I tried disabling jax.Array with jax.config.update('jax_array', False) but that does not work with jax and jaxlib 0.4.7.
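One workaround reported for this class of error is to keep every array out of the aggregate file by passing SaveArgs with aggregate disabled for each leaf, e.g. something like save_args = jax.tree_util.tree_map(lambda _: orbax.checkpoint.SaveArgs(aggregate=False), ckpt), so sharded arrays go to Tensorstore instead. A self-contained sketch of the idea, using a stand-in SaveArgs class rather than the real orbax one:

```python
from dataclasses import dataclass

@dataclass
class SaveArgs:  # stand-in for orbax.checkpoint.SaveArgs, for illustration only
    aggregate: bool = True

def no_aggregate_save_args(tree):
    """Build a pytree matching `tree` with SaveArgs(aggregate=False) at each leaf."""
    if isinstance(tree, dict):
        return {k: no_aggregate_save_args(v) for k, v in tree.items()}
    return SaveArgs(aggregate=False)

ckpt = {'state': {'params': [1.0, 2.0]}, 'config': 'model-config'}
save_args = no_aggregate_save_args(ckpt)
assert save_args['state']['params'].aggregate is False
```

This is only a sketch of the shape save_kwargs={'save_args': ...} expects; whether it resolves the error for a sharded TrainState depends on the orbax version in use.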

Is save/load with different number of devices supported?

I am working from a maxtext fork and trying to load a checkpoint saved with four devices using only 2 devices (for evaluation).
Is this supported? My code hangs at the checkpoint load, and I'm not sure whether it's my code or whether the behavior is simply not supported. If the former, a pointer would be helpful!

def map_to_pspec(data, pspec):
    if isinstance(data, (jax.Array, jax.ShapeDtypeStruct)) and pspec is not None:
        return type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
    else:
        return type_handlers.RestoreArgs()

restore_args = jax.tree_util.tree_map(
        map_to_pspec, abstract_unboxed_pre_state, state_mesh_annotations
    )
latest_step =  checkpoint_manager.latest_step()
raw_state = checkpoint_manager.restore(
            latest_step,
            {"model_state": abstract_unboxed_pre_state},
            {"model_state": {"restore_args": restore_args}},
        )
return raw_state["model_state"]

Restoring checkpoint throws a 'FAILED PRECONDITION' error

Hi, I'm quite new to Orbax, and I am trying to use it to checkpoint my flax models. The contents of the checkpoint are vanilla flax TrainState objects, which I save as ckpt_manager.save(epoch, {'state': state}, metrics = {'accuracy': accuracy}).

I then try to restore the checkpoint with ckpt_manager.restore(ckpt_manager.best_step()), but I get back
ValueError: FAILED_PRECONDITION: Error opening "zarr" driver: Expected "compressor" of {"id":"gzip","level":1} but received: {"id":"zstd","level":1} [source locations='tensorstore/driver/driver.cc:117'] [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128}},\"driver\":\"zarr\",\"kvstore\":{\"driver\":\"file\",\"path\":\"/checkpoints/1352/default/state.opt_state.0.mu.params_1.Dense_0.kernel/\"},\"metadata\":{\"compressor\":{\"id\":\"gzip\",\"level\":1}}}']

In my python environment I have installed the latest versions available on PyPi of both orbax and tensorstore. Thanks for your help!

`orbax-0.1.8`: Bad package release?

Quick fix: Use pip install orbax==0.1.7 until this issue is fixed.

Something seems to have gone wrong:

$ pip download orbax
Collecting orbax==0.1.8
  Using cached orbax-0.1.8.tar.gz (1.6 kB)
  error: subprocess-exited-with-error
  
  Γ— python setup.py egg_info did not run successfully.
  β”‚ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... error
error: metadata-generation-failed

Γ— Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

Downloading the package manually and inspecting its contents (from PyPI: https://pypi.org/project/orbax/0.1.8/#files):

$ wget https://files.pythonhosted.org/packages/1a/b2/bf13b6c0f73952d4ce11c13f439b1fbd688ad9a9a875ab5cf6106412644d/orbax-0.1.8.tar.gz
$ tar xzvf orbax-0.1.8.tar.gz
orbax-0.1.8/
orbax-0.1.8/PKG-INFO
orbax-0.1.8/README.md
orbax-0.1.8/orbax.egg-info/
orbax-0.1.8/orbax.egg-info/PKG-INFO
orbax-0.1.8/orbax.egg-info/SOURCES.txt
orbax-0.1.8/orbax.egg-info/dependency_links.txt
orbax-0.1.8/orbax.egg-info/top_level.txt
orbax-0.1.8/setup.cfg
orbax-0.1.8/setup.py

It should look something like the previous release:

$ wget https://files.pythonhosted.org/packages/15/44/d8a13c81c47302440e861d755e496742df25d293d49c935aeb3e05f04ce3/orbax-0.1.7.tar.gz
$ tar xzvf orbax-0.1.7.tar.gz
orbax-0.1.7/LICENSE
orbax-0.1.7/README.md
orbax-0.1.7/orbax/__init__.py
orbax-0.1.7/orbax/checkpoint/README.md
orbax-0.1.7/orbax/checkpoint/__init__.py
orbax-0.1.7/orbax/checkpoint/abstract_checkpointer.py
orbax-0.1.7/orbax/checkpoint/aggregate_handlers.py
orbax-0.1.7/orbax/checkpoint/array_checkpoint_handler.py
orbax-0.1.7/orbax/checkpoint/async_checkpoint_handler.py
orbax-0.1.7/orbax/checkpoint/async_checkpointer.py
orbax-0.1.7/orbax/checkpoint/checkpoint_handler.py
orbax-0.1.7/orbax/checkpoint/checkpoint_manager.py
orbax-0.1.7/orbax/checkpoint/checkpoint_utils.py
orbax-0.1.7/orbax/checkpoint/checkpoint_utils_test.py
orbax-0.1.7/orbax/checkpoint/checkpointer.py
orbax-0.1.7/orbax/checkpoint/future.py
orbax-0.1.7/orbax/checkpoint/json_checkpoint_handler.py
orbax-0.1.7/orbax/checkpoint/json_checkpoint_handler_test.py
orbax-0.1.7/orbax/checkpoint/lazy_utils.py
orbax-0.1.7/orbax/checkpoint/msgpack_utils.py
orbax-0.1.7/orbax/checkpoint/orbax_checkpoint.ipynb
orbax-0.1.7/orbax/checkpoint/pytree_checkpoint_handler.py
orbax-0.1.7/orbax/checkpoint/test_utils.py
orbax-0.1.7/orbax/checkpoint/transform_utils.py
orbax-0.1.7/orbax/checkpoint/transform_utils_test.py
orbax-0.1.7/orbax/checkpoint/type_handlers.py
orbax-0.1.7/orbax/checkpoint/utils.py
orbax-0.1.7/orbax/checkpoint/utils_test.py
orbax-0.1.7/orbax/conftest.py
orbax-0.1.7/pyproject.toml
orbax-0.1.7/PKG-INFO

=> The orbax-0.1.8 package should probably be yanked, and a new package orbax-0.1.9 should be released.

Ideally, the release process would be automated via Github action, like for example here:
https://github.com/google/flax/blob/main/.github/workflows/pythonpublish.yml

Make TensorFlow optional in orbax-export

Hi,

Is it possible to make TensorFlow optional in Export?

In other TF-related repos it's quite common to make TensorFlow optional, since the user can choose between different TF variants (e.g., tensorflow-cpu, tensorflow-macos).

Created #365

orbax equivalent to load_state_dict(strict=False)

I am trying to convert checkpoints to orbax format and getting a bad traceback:

  File "/home/sam/.conda/envs/char-latest/lib/python3.8/site-packages/orbax/checkpoint/utils.py", line 260, in _reconstruct_from_keypath
    if not isinstance(result, list) and key_name not in result:
TypeError: argument of type 'LazyValue' is not iterable

Higher-level problem: is it possible to tell PyTreeCheckpointHandler not to load params that aren't present in the model, like load_state_dict(strict=False) in PyTorch?

orbax-checkpoint==0.2.3
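Orbax has transformation utilities aimed at this kind of partial restore, but as a rough sketch of strict=False behaviour done by hand: restore the checkpoint into a plain tree first, then merge in only the entries the model declares, ignoring extras. All names here are illustrative:

```python
def merge_loose(model_tree, ckpt_tree):
    """Return model_tree updated with matching entries from ckpt_tree.
    Checkpoint keys absent from the model are dropped; model keys absent
    from the checkpoint keep their initial values (like strict=False)."""
    if isinstance(model_tree, dict) and isinstance(ckpt_tree, dict):
        return {k: merge_loose(v, ckpt_tree[k]) if k in ckpt_tree else v
                for k, v in model_tree.items()}
    return ckpt_tree

model = {"encoder": {"w": 0.0}, "new_head": {"w": 0.0}}
ckpt = {"encoder": {"w": 1.0}, "old_head": {"w": 9.0}}
merged = merge_loose(model, ckpt)
assert merged == {"encoder": {"w": 1.0}, "new_head": {"w": 0.0}}
```

Restoring free-form (items=None) and then merging like this also sidesteps the LazyValue traceback above, since no partial template is handed to the handler.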

Checkpointing fails if a key includes a forward slash

To reproduce:

import orbax.checkpoint

CHECKPOINT_DIR = "./checkpoints"
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
tree = {"test/test": 42}
checkpointer.save(CHECKPOINT_DIR, tree)

Raises

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints.orbax-checkpoint-tmp-1690216561285472/test/test'

This is a problem in particular for haiku parameter names, which can include forward slashes by default.
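Until the underlying issue is fixed, one hedged workaround is to sanitize key names before saving and undo the substitution after restoring. `SLASH` below is an arbitrary placeholder assumed not to occur in real key names:

```python
SLASH = "__SLASH__"  # arbitrary placeholder, assumed absent from real keys

def encode_keys(tree):
    """Recursively replace '/' in dict keys with a filesystem-safe token."""
    if isinstance(tree, dict):
        return {k.replace("/", SLASH): encode_keys(v) for k, v in tree.items()}
    return tree

def decode_keys(tree):
    """Inverse of encode_keys."""
    if isinstance(tree, dict):
        return {k.replace(SLASH, "/"): decode_keys(v) for k, v in tree.items()}
    return tree

tree = {"test/test": 42}
safe = encode_keys(tree)          # keys no longer contain '/'
assert decode_keys(safe) == tree  # round-trips losslessly
```

Saving `encode_keys(tree)` and applying `decode_keys` to the restored result would keep the on-disk names free of path separators.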

error saving state

Hello, I get the error below while trying to save the state of a pmapped model:

Error parsing object member "base": Error parsing object member "metadata": Error parsing object member "chunks": Array has length 0 but should have length 1

The code is pretty basic:

        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        ckpt = {'state': state} #, 'config': cfg  ,'loss':loss
        save_args = orbax_utils.save_args_from_target(ckpt)
        orbax_checkpointer.save(chechpoint_epoch_folder, ckpt, save_args=save_args)

Versions:

Python 3.10.12
orbax-checkpoint             0.2.3
jax                          0.4.10
flax                         0.7.0

The code generally runs inside the official JAX NVIDIA container

nvcr.io/nvdlfwea/jax/jax:23.05-py3

The model is distributed (replicated, not sharded) over 2 GPUs.

Full error:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspaces/Jax_cuda_med/j_med/ztest_geometric/run_geometric_sv.py", line 341, in <module>
    main_train(cfg)
  File "/workspaces/Jax_cuda_med/j_med/ztest_geometric/run_geometric_sv.py", line 317, in main_train
    state,loss=train_epoch(batch_images[index,:,:,:,:,:],batch_labels[index,:,:,:,:,:],batch_images_prim,curr_label,epoch,index
  File "/workspaces/Jax_cuda_med/j_med/ztest_geometric/run_geometric_sv.py", line 241, in train_epoch
    save_checkpoint(index,epoch,cfg,checkPoint_folder,state,np.mean(epoch_loss))
  File "/workspaces/Jax_cuda_med/j_med/ztest_geometric/data_utils.py", line 76, in save_checkpoint
    orbax_checkpointer.save(chechpoint_epoch_folder, ckpt, save_args=save_args)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpointer.py", line 77, in save
    self._handler.save(tmpdir, item, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 421, in save
    asyncio.run(async_save(directory, item, *args, **kwargs))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 414, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 394, in async_save
    commit_futures = await asyncio.gather(*copy_futures)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 386, in serialize
    return await handler.serialize(value, info, args)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py", line 502, in serialize
    await serialization.async_serialize(
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/array_serialization/serialization.py", line 153, in async_serialize
    ts.Spec(tensorstore_spec),
ValueError: Error parsing object member "base": Error parsing object member "metadata": Error parsing object member "chunks": Array has length 0 but should have length 1 [source locations='tensorstore/internal/json_binding/json_binding.h:859,tensorstore/internal/json_binding/json_binding.h:859,tensorstore/internal/json_binding/json_binding.h:859']

Thanks for any clue how to resolve it

Any way to have CheckpointManager write earlier checkpoints?

Hi,

I have a directory of checkpoints from a previous run in which I had set

    options = ocp.CheckpointManagerOptions(save_interval_steps=3000, max_to_keep=50)

I ran the code longer than expected, and now have 50 checkpoints:

126000
129000
132000
135000
...
273000

I had wanted to have all of the checkpoints starting from 0, so I should have chosen a larger max_to_keep. I restarted from the beginning, now using max_to_keep=200, but it seems that the CheckpointManager is not saving the checkpoints 0, 3000, 6000, etc, even though the directory doesn't have 200 items yet.

Is this expected behavior, and is there any easy way I can get Orbax to fill in older checkpoints up to max_to_keep and then start to delete them based on the checkpoint order, not the order in which they were written?

Thanks,

Henry

asyncio error with `orbax-checkpoint` in Python <= 3.9

checkpointer.restore throws asyncio error:

  File "/Users/ivyzheng/envs/tmp39/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py", line 97, in restore
    restored = self._handler.restore(directory, *args, item=item, **kwargs)
  File "/Users/ivyzheng/envs/tmp39/lib/python3.9/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 515, in restore
    byte_limiter = serialization._LimitInFlightBytes(concurrent_bytes)  # pylint: disable=protected-access
  File "/Users/ivyzheng/envs/tmp39/lib/python3.9/site-packages/jax/experimental/gda_serialization/serialization.py", line 127, in __init__
    self._cv = asyncio.Condition(lock=asyncio.Lock())
  File "/usr/local/Cellar/[email protected]/3.9.15/Frameworks/Python.framework/Versions/3.9/lib/python3.9/asyncio/locks.py", line 81, in __init__
    self._loop = events.get_event_loop()
  File "/usr/local/Cellar/[email protected]/3.9.15/Frameworks/Python.framework/Versions/3.9/lib/python3.9/asyncio/events.py", line 642, in get_event_loop
    raise RuntimeError('There is no current event loop in thread %r.'

This error only appears for python <= 3.9, and when installing only orbax-checkpoint, not orbax.

Repro from a clean Py39 environment

pip install orbax-checkpoint, then

import orbax.checkpoint
import numpy as np

pytree = {'a': np.arange(5), 'b': np.ones((5))}
ckptr = orbax.checkpoint.PyTreeCheckpointer()
path = './tmp/ckpt1'

ckptr.save(path, pytree, force=True)
restored = ckptr.restore(path)
print(restored)
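A possible workaround sketch: the failure comes from asyncio.Lock() being constructed outside a running loop, which on Python <= 3.9 requires a current event loop to already exist in the thread. Installing one before calling restore may avoid the error (untested across every version combination):

```python
import asyncio

def ensure_event_loop():
    """Install an event loop in the current thread if none is set yet."""
    try:
        asyncio.get_event_loop()
    except RuntimeError:
        asyncio.set_event_loop(asyncio.new_event_loop())

ensure_event_loop()
# ... ckptr.restore(path) should now find a usable event loop.
```

Upgrading to Python 3.10+ (where these primitives no longer bind to the loop at construction time) would be the cleaner fix.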

Loading only a part of the saved checkpoint (i.e. only model weights)

I want to implement the following functionality:

  • During training I'd like to save the whole flax TrainState object to a checkpoint so that I can use it to continue training.
  • When loading a model for inference from a checkpoint, I only want to load model params, which is a subtree of the TrainState pytree.

Unfortunately, I haven't been able to figure out how to implement such partial loading in orbax. Any guidance will be much appreciated.
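One approach (names here are illustrative): restore the checkpoint into a plain tree and then select just the params subtree, rather than supplying the full TrainState as the restore template. The selection step is plain tree navigation:

```python
def get_subtree(tree, path):
    """Walk a nested dict by a sequence of keys and return that subtree."""
    for key in path:
        tree = tree[key]
    return tree

# Hypothetical restored TrainState-shaped tree:
restored = {"step": 100,
            "params": {"dense": {"kernel": [1.0]}},
            "opt_state": {"mu": {"dense": {"kernel": [0.0]}}}}

params = get_subtree(restored, ["params"])
assert params == {"dense": {"kernel": [1.0]}}
```

The trade-off is that the optimizer state is still read from disk; avoiding that entirely would require passing a template that only describes the params item, which depends on how the checkpoint was structured at save time.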

Sharded loading performance question

I'm benchmarking loading a 65B sharded transformer model on multiple GPUs on the same host. The checkpoint itself is not sharded, but when the model is being loaded, a correct sharding is being supplied.

I've noticed several performance-related peculiarities that I would like to understand better, hopefully you can comment on them.

  • For some reason, when the weights of the model are in scan-friendly format (that is, transformer block weights are stacked along the layer dimension in a single tensor), loading on 8 GPUs takes ~4x longer than when the weights for different layers are stored separately. This is surprising as I would expect loading time to be either unaffected or smaller for scan-friendly format.
  • Loading scan-friendly model using 2 GPUs takes 2.5x less time than when using 8 GPUs. This is also weird as I would expect the total number of disk reads to be approximately the same in both cases.
  • When OCDBT is enabled, checkpoint loading seems to take ~5-10% longer. Is this to be expected?

Any clarifications or suggestions how to speed things up would be much appreciated.

AttributeError: 'Config' object has no attribute 'jax_coordination_service'

I am using orbax to checkpoint my models, but get the following error when I call checkpoint_manager.save:

AttributeError                            Traceback (most recent call last)
File .../lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py:465, in CheckpointManager.save(self, step, items, save_kwargs, metrics, force)
    400 def save(self,
    401          step: int,
    402          items: Union[Any, Mapping[str, Any]],
   (...)
    405          metrics: Optional[PyTree] = None,
    406          force: Optional[bool] = False) -> bool:
    407   """Saves the provided items.
    408 
    409   This method should be called by all hosts - process synchronization and
   (...)
    463     ValueError: if the checkpoint already exists.
    464   """
--> 465   if not force and not self.should_save(step):
    466     return False
    467   if self.reached_preemption(step):

File .../python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py:351, in CheckpointManager.should_save(self, step)
...
--> 336       jax.config.jax_coordination_service
    337       and multihost_utils.reached_preemption_sync_point(step)
    338   )

AttributeError: 'Config' object has no attribute 'jax_coordination_service'

Here is code to reproduce the bug:

from orbax.checkpoint import (CheckpointManager, CheckpointManagerOptions,
                              PyTreeCheckpointer)

path = 'tmp/my_checkpoint'
options = CheckpointManagerOptions(max_to_keep=1, create=True)
checkpoint_manager = CheckpointManager(directory=path,
                                       checkpointers=PyTreeCheckpointer(),
                                       options=options)
pytree = {'a': 1, 'b': 2}
step = 1
checkpoint_manager.save(step, pytree)

From this JAX issue it looks like jax_coordination_service has been removed. Thanks in advance!

Incompatibility with Haiku

Reopening an issue regarding incompatibility with Haiku naming conventions (similar to the previous issue). This was not a problem in v0.3.5.

Sample code:

from jax import numpy as jnp
import orbax.checkpoint as ocp
import haiku as hk


@hk.transform
def forward_fn(inputs):
  # net = hk.Linear(output_size=2) # This works
  net = hk.nets.MLP(
      output_sizes=[2, 2], activate_final=True)  # This doesn't work
  return net(inputs)


prng_seq = hk.PRNGSequence(0)
params = forward_fn.init(next(prng_seq), jnp.ones((1, 5)))

ckpt_dir = '/tmp/my-checkpoints/'
orbax_mngr = ocp.CheckpointManager(
    ckpt_dir,
    {'state': ocp.PyTreeCheckpointer()},
    options=ocp.CheckpointManagerOptions(max_to_keep=1),
)
orbax_mngr.save(step=0, items={'state': params})

The error:

Traceback (most recent call last):
  File "/workspaces/modularbayes/examples/bar.py", line 23, in <module>
    orbax_mngr.save(step=0, items={'state': params})
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 496, in save
    self._checkpointers[k].save(item_dir, item, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 79, in save
    self._handler.save(tmpdir, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 818, in save
    asyncio.run(async_save(directory, item, *args, **kwargs))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 811, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 786, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 893, in serialize
    open_future = ts.open(
ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/mlp/~/linear_0.b" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
sys:1: RuntimeWarning: coroutine 'async_serialize' was never awaited

Saving and restoring checkpoint hangs in Jupyter Notebook.

I have a training loop as shown below. I am running Python 3.10.10 and the latest versions of JAX (0.4.7), Flax (0.6.7), and Orbax (0.1.6). I am having some issues with the restore and save calls leading to the code hanging in Jupyter Notebook. When I call the train_model function, the code block freezes at either restore or save, but resumes if I run another code block. I think it could potentially have something to do with asyncio, but I am not totally sure. I had recently switched over from flax.checkpoints, where this wasn't an issue. Any help on this would be appreciated!

# Standard-library and third-party imports used below
import json
from functools import partial
from pathlib import Path

import numpy as np
import optax
import orbax.checkpoint
from jax import random

# Flax imports
from flax import serialization
from flax.training import orbax_utils, train_state

...

def train_model(model_path, dataset_path, dataset_adjustment='normalize',
                epochs=200, random_seed=0, batch_size=4, learning_rate=0.001,
                warmup_epochs=10, decay_epochs=100, decay_rate=0.5, decay_transition_epochs=10,
                optimizer=None, loss_weights=None):

    model_path = Path(model_path)
    model_parent_path = model_path.parent
    model_name = model_path.stem
    checkpoint_path = model_parent_path.joinpath(f'{model_name}_ckpts')
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    batch_metrics_log_path = model_parent_path.joinpath(f'{model_name}_batch_metrics_log')
    epoch_metrics_log_path = model_parent_path.joinpath(f'{model_name}_epoch_metrics_log')

    if batch_metrics_log_path.is_file():
        with open(batch_metrics_log_path, 'r') as f_batch_metrics_log:
            batch_metrics_log = json.load(f_batch_metrics_log)
    else:
        batch_metrics_log = []
    if epoch_metrics_log_path.is_file():
        with open(epoch_metrics_log_path, 'r') as f_epoch_metrics_log:
            epoch_metrics_log = json.load(f_epoch_metrics_log)
    else:
        epoch_metrics_log = []

    print('Loading datasets...\n')
    ds = load_datasets(dataset_path, adjustment=dataset_adjustment)
    train_images_shape = ds['train']['images'].shape
    input_size = train_images_shape[1:3]
    coords_max_length = \
        max([len(coords) for coords in ds['train']['coords']] + [len(coords) for coords in ds['valid']['coords']])

    rng = random.PRNGKey(random_seed)

    warmup = [learning_rate * i / warmup_epochs for i in range(warmup_epochs)]
    constant = [learning_rate] * (epochs - warmup_epochs - decay_epochs)
    decay = [learning_rate * decay_rate ** np.ceil(i / decay_transition_epochs) for i in range(1, decay_epochs + 1)]
    schedule = warmup + constant + decay

    if optimizer is None:
        optimizer = partial(optax.adabelief, eps=1e-8)
    tx = optax.inject_hyperparams(optimizer)(learning_rate=learning_rate)

    if loss_weights is None:
        loss_weights = {
            'rmse': 0.4,
            'bce': 0.2,
            'smoothf1': 1
        }

    mgr_options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2)
    handlers = {'state': orbax.checkpoint.PyTreeCheckpointer()}
    ckpt_mgr = orbax.checkpoint.CheckpointManager(
        directory=checkpoint_path,
        checkpointers=handlers,
        options=mgr_options
    )

    if (next(checkpoint_path.iterdir(), None) is None) and model_path.is_file():
        print(f'Loading existing model weights from {model_path}...\n')
        ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
        variables = ckptr.restore(model_path, item=None)
    else:
        variables = None

    print('Creating new TrainState...\n')
    state = create_train_state(rng, input_size, tx, variables)
    latest_epoch = ckpt_mgr.latest_step()
    if latest_epoch is not None:
        print(f'Loading latest checkpoint from {checkpoint_path}...\n')
        restore_args = orbax_utils.restore_args_from_target(state, mesh=None)
        state = ckpt_mgr.restore(
            step=latest_epoch,
            items={'state': state},
            restore_kwargs={'state': {'restore_args': restore_args}}
        )['state']

    for epoch_learning_rate in schedule:

        state, batch_metrics, epoch_metrics = \
            train_epoch(state, ds, batch_size, loss_weights, epoch_learning_rate, input_size, coords_max_length)

        batch_metrics_log += batch_metrics
        epoch_metrics_log += [epoch_metrics]

        save_args = orbax_utils.save_args_from_target(state)
        ckpt_mgr.save(
            step=state.epoch,
            items={'state': state},
            save_kwargs={'state': {'save_args': save_args}}
        )

        with open(batch_metrics_log_path, 'w') as f_batch_metrics_log:
            json.dump(batch_metrics_log, f_batch_metrics_log, indent=4)
        with open(epoch_metrics_log_path, 'w') as f_epoch_metrics_log:
            json.dump(epoch_metrics_log, f_epoch_metrics_log, indent=4)

    variables = {'params': state.params, 'batch_stats': state.batch_stats, 'input_size': input_size}
    bytes_model = serialization.to_bytes(variables)

    with open(model_path, 'wb') as f_model:
        f_model.write(bytes_model)

Incompatibility with Haiku

I am trying to use Orbax for checkpointing a model in Haiku, but it's been quite problematic.

The problem seems to be that Orbax struggles with keys containing the / character. Haiku uses / for managing parameters in nested modules, such as a simple MLP network.

Here is a minimal example replicating the problem:

import pathlib
from jax import numpy as jnp
import orbax
import orbax.checkpoint
import haiku as hk

@hk.transform
def forward_fn(inputs):
  # net = hk.Linear(output_size=2) # This works
  net = hk.nets.MLP(output_sizes=[2, 2], activate_final=True)  # This doesn't work
  return net(inputs)

prng_seq = hk.PRNGSequence(0)
params = forward_fn.init(next(prng_seq), jnp.ones((1, 5)))

ckpt_dir = str(pathlib.Path.home() / 'foo')
orbax_mngr = orbax.checkpoint.CheckpointManager(
    ckpt_dir,
    {'state': orbax.checkpoint.PyTreeCheckpointer()},
    options=orbax.checkpoint.CheckpointManagerOptions(max_to_keep=1),
)
orbax_mngr.save(step=0, items={'state': params})

For the MLP network, params is a tree whose first-level keys are ['mlp/~/linear_0', 'mlp/~/linear_1'], which creates a problem when I try to save it using the checkpoint manager.

FileNotFoundError: [Errno 2] No such file or directory: '/home/ubuntu/foo/0.orbax-checkpoint-tmp-1685818375792355/state.orbax-checkpoint-tmp-1685818375868924/mlp/~/linear_0.w'
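A possible workaround, sketched below: apply a reversible substitution to the offending characters in the flat Haiku keys before saving, and invert it after restoring. The replacement tokens are arbitrary choices for illustration, not an Orbax API:

```python
# Workaround sketch: Haiku flat param names contain '/' and '~', which clash
# with filesystem paths and JSON Pointers. Apply a reversible substitution to
# the keys before saving and invert it after restoring. The replacement
# tokens below are arbitrary; pick any strings that cannot appear in names.
SUBS = [('/', '__slash__'), ('~', '__tilde__')]

def encode_key(key):
    for old, new in SUBS:
        key = key.replace(old, new)
    return key

def decode_key(key):
    for old, new in SUBS:
        key = key.replace(new, old)
    return key

params = {'mlp/~/linear_0': {'w': [1.0]}, 'mlp/~/linear_1': {'w': [2.0]}}
safe = {encode_key(k): v for k, v in params.items()}       # save this tree
assert all('/' not in k and '~' not in k for k in safe)
roundtrip = {decode_key(k): v for k, v in safe.items()}    # after restore
assert roundtrip == params
```

Encoding must happen on the top-level keys before handing the tree to the checkpointer, and decoding after restoring, so the model itself never sees the substituted names.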

Main thread blocks for a long time when using GCS paths with AsyncCheckpointer

Hello, great project you guys have here! I recently started using the AsyncCheckpointer / PyTreeCheckpointHandler with GCS paths and noticed that checkpointing blocks the main thread (the one calling CheckpointManager.save()) for roughly 30 seconds. During this period there is a small amount of network traffic (~30-50 kB/s) but no bulk data transfer. The bulk data transfers happen after CheckpointManager.save() returns, as expected.

I did a bit of digging and as far as I can tell, the time is spent creating metadata files for every array in the pytree. The pytree has a couple dozen arrays and these files are created sequentially, which is slow on GCS. Would it be possible to reduce the number of these files or create them in parallel? Or is there a different workaround to save fewer files (note that I am also using sharded arrays)?

How to restore checkpoints if not all arrays are sharded?

Hi,
I am trying to restore a checkpoint where some of its arrays are not partitioned. I believe this is a common use case, as one may not want (or may not be able) to partition all parameters in a model. The following is a minimal example:

import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec
import orbax.checkpoint
import flax.training.orbax_utils

mesh_shape = (2, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, ('x', 'y'))

sharded = jax.device_put(np.arange(4).reshape(2, 2), jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')))
unsharded = jax.device_put(np.arange(1), jax.sharding.NamedSharding(mesh, PartitionSpec()))
ckpt = dict(sharded=sharded, unsharded=unsharded)

jax.distributed.initialize("localhost:8889", num_processes=1, process_id=0)

async_ckptr = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)
async_ckpt_mgr = orbax.checkpoint.CheckpointManager('/tmp/example', async_ckptr)
async_ckpt_mgr.wait_until_finished()
async_ckpt_mgr.save(0, ckpt)
async_ckpt_mgr.wait_until_finished()

restore_args = flax.training.orbax_utils.restore_args_from_target(ckpt)
restored_ckpt = async_ckpt_mgr.restore(0, items=ckpt, restore_kwargs={'restore_args': restore_args})

I get the following error message:
"ValueError: Sharding of jax.Array cannot be None. Provide mesh and mesh_axes OR sharding."

I then found out that I can transform unsharded into a numpy.array, restore it, then assign again an empty partition. As follows:

ref_ckpt = dict(sharded=sharded, unsharded=np.array(unsharded))
restore_args = flax.training.orbax_utils.restore_args_from_target(ref_ckpt)
restored_ckpt = async_ckpt_mgr.restore(0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args})
restored_ckpt["unsharded"] = jax.device_put(restored_ckpt["unsharded"], NamedSharding(mesh, PartitionSpec()))

This seems to work as desired. Would you say this is the right way to do it? Is there an easier way?

Thanks!

AsyncCheckpointer hang

We are doing a big run and about 20% of the time we call checkpoint_manager.save we get a hang for hours that eventually crashes.

  • When it hangs, the last log statement is jax.experimental.array_serialization.serialization:431 | Process 0 successfully set key tensorstore_checkpoint_1 in the kv store -- We never see Thread joined successfully.
  • One workaround I can think of is to use Checkpointer instead of AsyncCheckpointer, but I doubt this will work given that we see the hang before training resumes.
  • Storage target is gs://uscentral2_user/checkpoints/ (managed with Autoclass)
  • checkpoints are about 200 GB and seem fine (when there's not a hang)
  • usually it's not the first save after relaunch that blows up
  • save gets called roughly every 15 minutes.
  • Checkpoints normally take 38 seconds to save. The files don't get modified after those 38 seconds, so it seems we are not doing any asynchronous work.
  • we reuse one instance of `checkpoint_manager` for the whole program.

Any ideas for good debugging steps?

Checkpoint manager setup code

from etils import epath
from orbax import checkpoint
from orbax.checkpoint.checkpoint_manager import (
    CheckpointManager,
    CheckpointManagerOptions,
    Checkpointer,
    AsyncCheckpointer,
)
from orbax.checkpoint.checkpoint_handler import CheckpointHandler

p = epath.Path("gs://uscentral2_user/checkpoints/")
h = checkpoint.PyTreeCheckpointHandler(concurrent_gb=100)
mstate_ckptr = AsyncCheckpointer(h)
mngr = CheckpointManager(
    p,
    checkpointers={
        'model_state': mstate_ckptr,
        "dataloader_state": Checkpointer(DataloaderHandler()),
    },
    options=CheckpointManagerOptions(create=create, cleanup_tmp_directories=create),
)

# call save
mngr.save(step, {
    'model_state': state,
    "dataloader_state": "dataloader_state",
})

Lines we don't see in the hang case

jax.experimental.array_serialization.serialization:459 | Thread joined successfully
jax.experimental.array_serialization.serialization:462 | Error check finished successfully
blocking_key_value_get on key tensorstore_checkpoint_1 was successfully completed.

Lines we always get

2023-07-04 11:16:17,348 | INFO | j000 | train:84 | To see full metrics 'tensorboard --logdir=/home/roller/maxtext-train-sweeps/invincible_32b_v2//+project-invincible_32b/invincible_32b_v2/tensorboard/'
2023-07-04 11:16:17,358 | INFO | j000 | jax.experimental.array_serialization.serialization:462 | Error check finished successfully
2023-07-04 11:16:17,359 | INFO | j000 | jax.experimental.array_serialization.serialization:469 | blocking_key_value_get on key tensorstore_checkpoint_0 was successfully completed.
2023-07-04 11:16:18,077 | INFO | j000 | absl:87 | Saving item to gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/model_state. Waiting for thread to finish save.
2023-07-04 11:16:18,077 | INFO | j000 | jax.experimental.array_serialization.serialization:462 | Error check finished successfully
2023-07-04 11:16:18,079 | INFO | j000 | jax.experimental.array_serialization.serialization:469 | blocking_key_value_get on key tensorstore_checkpoint_0 was successfully completed.
2023-07-04 11:16:49,089 | INFO | j000 | jax.experimental.array_serialization.serialization:408 | Starting commit to storage layer by process: 0
2023-07-04 11:16:49,090 | INFO | j000 | absl:67 | Saving item to gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/dataloader_state.
2023-07-04 11:16:52,109 | INFO | j000 | absl:540 | Finished saving checkpoint to `gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/dataloader_state`.
2023-07-04 11:16:52,125 | INFO | j000 | absl:67 | Saving item to gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/config.
2023-07-04 11:16:53,982 | INFO | j000 | absl:540 | Finished saving checkpoint to `gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/config`.
2023-07-04 11:16:53,998 | INFO | j000 | absl:67 | Saving item to gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/train_rng.
2023-07-04 11:16:54,131 | INFO | j000 | jax.experimental.array_serialization.serialization:413 | Finished committing to storage layer by process: 0
2023-07-04 11:16:54,131 | INFO | j000 | jax.experimental.array_serialization.serialization:420 | Key used for barrier is tensorstore_checkpoint_1 for process 0
2023-07-04 11:16:55,722 | INFO | j000 | absl:540 | Finished saving checkpoint to `gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/train_rng`.
2023-07-04 11:16:55,740 | INFO | j000 | train:772 | saved a checkpoint in -38.38 seconds (synchronous)
2023-07-04 11:16:56,775 | INFO | j000 | jax.experimental.array_serialization.serialization:423 | Finished waiting at barrier for process 0
2023-07-04 11:16:56,922 | INFO | j000 | absl:540 | Finished saving checkpoint to `gs://uscentral2_user/checkpoints/invincible_32B/run_v2/+project-invincible_32b/10750/model_state`.
2023-07-04 11:16:56,923 | INFO | j000 | jax.experimental.array_serialization.serialization:428 | on_commit_callback successfully ran!
2023-07-04 11:16:56,938 | INFO | j000 | jax.experimental.array_serialization.serialization:431 | Process 0 successfully set key tensorstore_checkpoint_1 in the kv store

Package(s) vendoring issues

What is the supposed way to distribute orbax-checkpoint and orbax-export? It is totally unclear to me:

  1. How do I get notified about new releases?
  2. What is the 'ground truth' source for the packages: PyPI or GitHub?
  3. What is the intended way to deal with conflicts due to the package namespace in Linux distros like Arch?

From my perspective, the current situation is quite out of the ordinary and tedious to manage over time. It would be great to know the maintainers' opinion on the issue.

unable to import orbax.checkpoint

Hi, I am unable to import orbax.checkpoint.

----> 1 import orbax.checkpoint

File ~/.local/lib/python3.10/site-packages/orbax/checkpoint/__init__.py:20
     17 import functools
     19 from orbax.checkpoint import aggregate_handlers
---> 20 from orbax.checkpoint import checkpoint_utils
     21 from orbax.checkpoint import lazy_utils
     22 from orbax.checkpoint import msgpack_utils

File ~/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_utils.py:25
     23 from jax.sharding import Mesh
     24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
     26 from orbax.checkpoint import utils
     29 PyTree = Any

File ~/.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:49
     44     raise ValueError('Coordinator address not set.')
     45   return coordinator_address.split(':')[0]
     48 def create_coordinator_server_and_context() -> (
---> 49     Tuple[ts.Context, Optional[ts.ocdbt.DistributedCoordinatorServer]]
     50 ):
     51   """Creates OCDBT coordinator and Tensorstore context.
     52 
     53   This function must be called at the start of the program across all processes
   (...)
     70     Tuple of ts.Context and OCDBT coordinator server object.
     71   """
     72   jax_global_state = jax._src.distributed.global_state  # pylint: disable=protected-access

AttributeError: module 'tensorstore' has no attribute 'ocdbt'

Here are the following versions:

In [13]: sys.version
Out[13]: '3.10.10 (main, Mar  5 2023, 22:26:53) [GCC 12.2.1 20230201]'

In [14]: jax.__version__
Out[14]: '0.4.8'

In [15]: orbax.__version__
Out[15]: '0.1.7'

on: Kernel: x86_64 Linux 6.2.13-arch1-1

Running orbax from an existing asyncio event loop

Currently, the checkpoint loading code calls asyncio.run even when the synchronous checkpointer impl is being used. This is problematic when calling orbax from within an existing asyncio event loop, because this method cannot be called within an existing loop and the existing loop should be reused instead. Having an existing event loop seems to be a relatively common scenario when running within a webserver or, say, from jupyter, so it probably needs to be addressed.
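The constraint described above — asyncio.run raising when a loop is already running in the current thread — can be worked around by delegating the coroutine to a fresh loop on a worker thread. A minimal stdlib sketch (the coroutine here is a stand-in, not Orbax code):

```python
import asyncio
import threading

def run_coroutine_blocking(coro):
    # Run `coro` to completion even if an event loop is already running in
    # this thread, by giving it a fresh loop on a worker thread.
    result = {}
    def worker():
        result['value'] = asyncio.run(coro)
    t = threading.Thread(target=worker)
    t.start()
    t.join()
    return result['value']

async def fake_restore():
    # Stand-in for a restore implemented as a coroutine.
    await asyncio.sleep(0)
    return {'step': 7}

async def main():
    # Calling asyncio.run(fake_restore()) directly here would raise
    # RuntimeError, because main() is itself running inside a loop.
    return run_coroutine_blocking(fake_restore())

print(asyncio.run(main()))  # {'step': 7}
```

This pays a thread-spawn cost per call; a longer-lived worker loop would amortize it, but the shape of the fix is the same.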

orbax-export does not support arbitrary nested observations

Hi,

It seems that orbax.export currently only supports functions of the form fn(params, x) where x is Mapping[str, Array]. Is that correct? I have use cases in models where x can have multiple levels of nesting.

What would be the best approach for handling this with Orbax? I found that I can save a SavedModel via a naive approach using tf.function tracing, but somehow it fails with Orbax, and I am not entirely sure what is causing this.

best practice for saving / restoring the position in the dataset (iterator)

Hi,

I noticed it is mentioned in the documentation that a tf.data.Iterator could be checkpointed.

I tried this, but it didn't work:

import orbax.checkpoint
checkpointer = orbax.checkpoint.PyTreeCheckpointer()

ds = tf.data.Dataset.load(dataset_path)
ds = ds.shuffle(buffer_size=10000, seed=12345)
it = iter(ds)

# ... train for some number of iterations
checkpointer.save(checkpoint_path, { 'ds_iterator': it })

and then to restore I tried:

restored = checkpointer.restore(checkpoint_path)
print(restored)
# { 'ds_iterator': None }

In particular, how would the checkpointer "know" that when restoring, it needs to re-do the shuffle of the dataset, and with the particular random seed? Seems like none of that information could be held in the iterator, so this must not be the right way. Any advice appreciated!

Thanks,

Henry
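One manual approach, sketched below with a plain Python generator rather than tf.data: checkpoint only the shuffle seed and the number of examples consumed, then rebuild the dataset and fast-forward on restore. Everything here is illustrative; none of it is Orbax or TensorFlow API:

```python
import random

def make_iterator(seed, skip=0):
    # Deterministically rebuild a shuffled stream from (seed, position).
    # Instead of checkpointing the iterator object, checkpoint only the
    # shuffle seed and how many examples were consumed, then fast-forward.
    rng = random.Random(seed)
    data = list(range(10))
    rng.shuffle(data)
    it = iter(data)
    for _ in range(skip):
        next(it)
    return it

it = make_iterator(seed=12345)
for _ in range(4):                      # "train" for four steps
    next(it)

state = {'seed': 12345, 'position': 4}  # this tiny dict is what gets saved

restored = make_iterator(state['seed'], skip=state['position'])
assert next(restored) == next(make_iterator(12345, skip=4))
```

This answers the "how would the checkpointer know about the shuffle" concern directly: it doesn't need to, because the seed and position together determine the stream.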

Restore latest checkpoint

Hello,
is there available functionality to restore the latest checkpoint stored in a directory? I'm looking for something analogous to flax.checkpoints.latest_checkpoint. In particular, I would like to obtain the path to the latest checkpoint in a given directory.

Currently it is not super clear to me what happens when I pass a directory to CheckpointManager.restore. Do I automatically get the latest checkpoint? If so, is there a way to access its path?

Thanks!
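CheckpointManager does expose latest_step() (it appears elsewhere in this thread as ckpt_mgr.latest_step()). The underlying idea — step directories are named by their step number, so the latest checkpoint is the numeric maximum — can be sketched with the stdlib:

```python
import os
import tempfile

def latest_step(directory):
    # Orbax-style layout: each checkpoint lives in a directory named after
    # its step number, so the latest step is the numeric maximum.
    steps = [int(d) for d in os.listdir(directory) if d.isdigit()]
    return max(steps) if steps else None

root = tempfile.mkdtemp()
for s in (0, 100, 250):
    os.mkdir(os.path.join(root, str(s)))

print(latest_step(root))                            # 250
latest_path = os.path.join(root, str(latest_step(root)))
```

With the real API, mngr.latest_step() plus the manager's directory gives the same path, and restore is then called with that step explicitly rather than with a bare directory.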

Which weights does each worker load?

I wrote a PyTorch-to-Orbax checkpoint converter that saves each param to one file. I am concerned that each worker will load the full huge model into CPU RAM and die.

I have two questions:

  1. Is this the right mental model of what will happen if I try to load this checkpoint with large parallelism?
  • every worker pulls every full file from blob storage, slices the param, and takes what it needs.
  • If I saved 4 files per param and loaded with 16-way parallelism, every worker would still load 1 smaller file and take what it needs.
  • If I saved 4 files per param and loaded with 2-way parallelism, every worker would load 2 smaller files and concatenate?
  2. If my mental model is reasonable, it makes sense to improve my converter to save more files per parameter. Is there a way to simulate this parallelism without running the script on N devices? One idea that comes to mind is
TPU_ACCELERATOR_TYPE='' JAX_PLATFORM='cpu' XLA_FLAGS="--xla_force_host_platform_device_count=1024" python converter.py "$@"

on a big CPU machine.

Existing Unsharded Format

  • note: these listings omit the .zarray files, which are present.

(produced by converter+orbax)

        ├── params.classi.classi_DC.kernel
        │   └── 0.0
        ├── params.layer_stack.decoder.mlp.in_DF.kernel
        │   └── 0.0.0
        ├── params.layer_stack.decoder.mlp.out_FD.kernel
        │   └── 0.0.0
        └── params.token_embedder.embedding_VD
            └── 0.0

Sharded format

(produced by maxtext+orbax)

        ├── params.layer_stack.decoder.mlp.in_DF.kernel
        │   ├── 0.0.0
        │   ├── 0.0.1
        │   ├── 0.0.2
        │   └── 0.0.3
        ├── params.layer_stack.decoder.mlp.out_FD.kernel
        │   ├── 0.0.0
        │   ├── 1.0.0
        │   ├── 2.0.0
        │   └── 3.0.0
        ├── params.token_embedder.embedding_VD
        │   ├── 0.0
        │   ├── 1.0
        │   ├── 2.0
        │   └── 3.0

If you copy checkpoints from HOME to gcs they can get deleted

because of this line

if is_gcs_path(path) and not (path / _COMMIT_SUCCESS_FILE).exists():

They don't have a success file but are in GCS, so Orbax thinks they are temporary directories and cleans them up.

I would suggest either always or never saving the COMMIT_SUCCESS file.

This is not blocking me (easy to just write extra commit success files once I found this) but it felt like I should report because it was very unexpected behavior and moving around checkpoints is super common.
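The "write extra commit success files" workaround mentioned above can be sketched as follows. The marker filename here is an assumption for illustration; check the _COMMIT_SUCCESS_FILE constant in your installed Orbax version before relying on it:

```python
import os
import tempfile

# Assumed marker filename -- verify against _COMMIT_SUCCESS_FILE in your
# Orbax version; this sketch just shows the shape of the workaround.
COMMIT_SUCCESS_FILE = 'commit_success.txt'

def mark_committed(checkpoint_root):
    # Drop an empty commit-marker file into each step directory so the GCS
    # cleanup path does not mistake a copied-in checkpoint for an abandoned
    # temporary directory.
    for step_dir in os.listdir(checkpoint_root):
        marker = os.path.join(checkpoint_root, step_dir, COMMIT_SUCCESS_FILE)
        open(marker, 'w').close()

root = tempfile.mkdtemp()
os.mkdir(os.path.join(root, '100'))
mark_committed(root)
print(os.path.exists(os.path.join(root, '100', COMMIT_SUCCESS_FILE)))  # True
```

On GCS the same marker would be written with a storage client rather than open(), but the idea is identical: make copied checkpoints indistinguishable from committed ones.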

Questions about loading and writing checkpoints in distributed training

Hi! I have a few questions about how orbax handles certain distributed training scenarios.

  • Imagine I have a model that is sharded over one mesh dimension (possibly over multiple nodes) while being replicated over another dimension. Is it true that, when writing a checkpoint, only one subset of nodes having a full replica of model weights will perform a writing operation? If so, what determines this subset?
  • Suppose my model is replicated over a number of nodes. When loading checkpoint, will each replica read weights independently in parallel?
    • If so, it might create a bottleneck if checkpoints are being read from a network filesystem with limited bandwidth. One alternative would be for only one replica to read the weights and then send them to other replicas using communication collectives, assuming inter-node network is fast. Is this something that can be done with orbax?

Thanks!

Orbax does not indicate its dependency versions, which easily leads to bugs

I am writing to point out that Orbax does not indicate its dependencies' versions.
This led me to a bug today on a Cloud TPU.

To be more specific, the recent Orbax 0.1.2 release uses from jax.sharding import Mesh, which fails with JAX 0.3.25. If the dependency versions were indicated in pyproject.toml, pip could figure out that there is a version conflict, and this out-of-the-blue error would not appear. Hope to see this addressed, given the growing importance of Orbax!

As a workaround, I have downgraded to Orbax 0.1.1.
