
safejax's Introduction

🔐 Serialize JAX, Flax, Haiku, or Objax model params with safetensors

safejax is a Python package to serialize JAX, Flax, Haiku, or Objax model params using safetensors as the tensor storage format, instead of relying on pickle. For more details on why safetensors is safer than pickle, see huggingface/safetensors.

Note that safejax supports serializing jax, flax, dm-haiku, and objax model parameters and has been tested with all of those frameworks. However, it is still in an early development phase, so there may be cases where it does not work as expected. If you have any feedback or bug reports, please open an issue at safejax/issues.

🛠️ Requirements & Installation

safejax requires Python 3.7 or above

pip install safejax --upgrade

💻 Usage

flax

  • Convert params to bytes in memory

    from safejax.flax import serialize, deserialize
    
    params = model.init(...)
    
    encoded_bytes = serialize(params)
    decoded_params = deserialize(encoded_bytes)
    
    model.apply(decoded_params, ...)
  • Serialize params to a params.safetensors file

    from safejax.flax import serialize, deserialize
    
    params = model.init(...)
    
    encoded_bytes = serialize(params, filename="./params.safetensors")
    decoded_params = deserialize("./params.safetensors")
    
    model.apply(decoded_params, ...)

dm-haiku

  • Model with params only

    from safejax.haiku import serialize, deserialize
    
    params = model.init(...)
    
    encoded_bytes = serialize(params)
    decoded_params = deserialize(encoded_bytes)
    
    model.apply(decoded_params, ...)
  • Model with params and state, e.g. the ExponentialMovingAverage state used by BatchNorm

    from safejax.haiku import serialize, deserialize
    
    params, state = model.init(...)
    params_state = {"params": params, "state": state}
    
    encoded_bytes = serialize(params_state)
    decoded_params_state = deserialize(encoded_bytes) # .keys() contains `params` and `state`
    
    model.apply(decoded_params_state["params"], decoded_params_state["state"], ...)
  • Model with params and state, serializing each of them separately

    from safejax.haiku import serialize, deserialize
    
    params, state = model.init(...)
    
    encoded_bytes = serialize(params)
    decoded_params = deserialize(encoded_bytes)
    
    encoded_bytes = serialize(state)
    decoded_state = deserialize(encoded_bytes)
    
    model.apply(decoded_params, decoded_state, ...)

objax

  • Convert params to bytes in memory, and convert back to VarCollection

    from safejax.objax import serialize, deserialize
    
    params = model.vars()
    
    encoded_bytes = serialize(params=params)
    decoded_params = deserialize(encoded_bytes)
    
    # `params` is model.vars(), so assigning to its variables updates the model
    for key, value in decoded_params.items():
        if key in params:
            params[key].assign(value.value)
    
    model(...)
  • Serialize params to a params.safetensors file

    from safejax.objax import serialize, deserialize
    
    params = model.vars()
    
    encoded_bytes = serialize(params=params, filename="./params.safetensors")
    decoded_params = deserialize("./params.safetensors")
    
    # `params` is model.vars(), so assigning to its variables updates the model
    for key, value in decoded_params.items():
        if key in params:
            params[key].assign(value.value)
    
    model(...)
  • Serialize params to a params.safetensors file and assign them during deserialization

    from safejax.objax import serialize, deserialize_with_assignment
    
    params = model.vars()
    
    encoded_bytes = serialize(params=params, filename="./params.safetensors")
    deserialize_with_assignment(filename="./params.safetensors", model_vars=params)
    
    model(...)

More detailed examples can be found in examples/ for flax, dm-haiku, and objax.

🤔 Why safejax?

safetensors defines an easy and fast (zero-copy) format to store tensors, while pickle has known weaknesses and security issues (unpickling untrusted data can execute arbitrary code). safetensors is also intended to be framework-agnostic, so that loading the tensors is trivial regardless of the framework used. More in-depth information can be found at huggingface/safetensors.
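To make the format concrete, here is a minimal stdlib-only sketch of the safetensors layout as described in the huggingface/safetensors spec: an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype, shape, and data_offsets, and then the raw little-endian tensor bytes. It only handles float32 tensors given as (shape, flat values) pairs; `encode` and `decode` are hypothetical names for illustration, not part of safejax or safetensors.

```python
import json
import struct
import sys
from array import array
from math import prod
from typing import Dict, List, Tuple

# (shape, flat row-major float values); a stand-in for a real ndarray
Tensor = Tuple[List[int], List[float]]


def encode(tensors: Dict[str, Tensor]) -> bytes:
    """Pack float32 tensors into a safetensors-style byte string."""
    header, chunks, offset = {}, [], 0
    for name in sorted(tensors):
        shape, values = tensors[name]
        if prod(shape) != len(values):
            raise ValueError(f"{name}: shape {shape} does not match {len(values)} values")
        buf = array("f", values)  # C float, 4 bytes on common platforms
        if sys.byteorder == "big":
            buf.byteswap()  # the format stores little-endian data
        data = buf.tobytes()
        header[name] = {
            "dtype": "F32",
            "shape": shape,
            "data_offsets": [offset, offset + len(data)],  # relative to the data section
        }
        chunks.append(data)
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    # 8-byte little-endian header size, then the JSON header, then raw tensor bytes
    return struct.pack("<Q", len(header_bytes)) + header_bytes + b"".join(chunks)


def decode(blob: bytes) -> Dict[str, Tensor]:
    """Parse a byte string produced by `encode` back into tensors."""
    (header_size,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_size])
    data = memoryview(blob)[8 + header_size :]
    tensors: Dict[str, Tensor] = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional free-form string metadata
            continue
        if info["dtype"] != "F32":
            raise NotImplementedError("this sketch only handles F32")
        begin, end = info["data_offsets"]
        buf = array("f")
        buf.frombytes(data[begin:end])
        if sys.byteorder == "big":
            buf.byteswap()
        tensors[name] = (info["shape"], buf.tolist())
    return tensors


blob = encode({"dense.kernel": ([2, 2], [1.0, 2.0, 3.0, 4.0]), "dense.bias": ([2], [0.5, -0.5])})
restored = decode(blob)
```

This sketch copies every tensor into a Python list; the real library can instead memory-map the file and expose the tensor bytes without copying, which is where the zero-copy property comes from.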

jax uses pytrees to store model parameters in memory, i.e. nested, dictionary-like containers whose leaves are jnp.DeviceArray tensors.
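Since safetensors stores a flat mapping from string names to tensors, a nested pytree has to be flattened into single-level keys before saving and rebuilt after loading. A minimal sketch of that round trip with plain Python containers, assuming a `.` delimiter (the delimiter safejax actually uses may differ); `flatten` and `unflatten` are hypothetical helper names:

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten(tree: Mapping, prefix: str = "", sep: str = ".") -> Dict[str, Any]:
    """Collapse nested mappings into a single level: {"a.b.c": leaf}."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, Mapping):  # plain dicts, flax FrozenDicts, ...
            flat.update(flatten(value, name, sep))
        else:
            flat[name] = value  # a leaf, e.g. an array
    return flat


def unflatten(flat: Mapping, sep: str = ".") -> Dict[str, Any]:
    """Rebuild the nested dict from delimiter-joined keys."""
    tree: Dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


params = {"params": {"Dense_0": {"kernel": [[1.0]], "bias": [0.0]}}}
flat = flatten(params)  # keys: "params.Dense_0.kernel", "params.Dense_0.bias"
assert unflatten(flat) == params
```

Keys that already contain the delimiter, and empty sub-dicts, would not survive this round trip, which is why the choice of delimiter matters.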

dm-haiku uses a dictionary whose keys are formatted as <level_1>/~/<level_2>, where the levels define the tree structure and /~/ is the separator between them, e.g. res_net50/~/initial_conv. The value for that key is not a jnp.DeviceArray but a dictionary of key-value pairs, e.g. w for the weights and b for the biases.
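Because haiku params are only two levels deep (module path, then parameter name), one way to flatten them is to keep the /~/ module path verbatim and append the parameter name with a delimiter. The sketch below assumes a `.` delimiter and uses hypothetical helper names; it is an illustration, not the key format safejax is guaranteed to produce.

```python
from typing import Any, Dict

# {"module/~/path": {"w": ..., "b": ...}}
HaikuParams = Dict[str, Dict[str, Any]]


def haiku_to_flat(params: HaikuParams, sep: str = ".") -> Dict[str, Any]:
    """{"mlp/~/linear_0": {"w": x}} -> {"mlp/~/linear_0.w": x}"""
    return {
        f"{module}{sep}{name}": value
        for module, leaves in params.items()
        for name, value in leaves.items()
    }


def flat_to_haiku(flat: Dict[str, Any], sep: str = ".") -> HaikuParams:
    """Split on the last delimiter to recover (module path, param name)."""
    params: HaikuParams = {}
    for key, value in flat.items():
        module, name = key.rsplit(sep, 1)
        params.setdefault(module, {})[name] = value
    return params


params = {
    "mlp/~/linear_0": {"w": [[1.0, 2.0]], "b": [0.0, 0.0]},
    "mlp/~/linear_1": {"w": [[3.0], [4.0]], "b": [0.0]},
}
assert flat_to_haiku(haiku_to_flat(params)) == params
```

Splitting on the last delimiter only works as long as parameter names themselves never contain it.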

objax defines a custom dictionary-like class named VarCollection that contains variables inheriting from BaseVar, another custom objax type.

flax defines a dictionary-like class named FrozenDict that is used to store the tensors in memory; it can be dumped either to bytes in MessagePack format or as a state_dict.

HuggingFace has no plans to extend safetensors to support anything beyond tensors, e.g. FrozenDicts; see their response at huggingface/safetensors/discussions/138.

The motivation behind safejax is therefore to provide an easy way to serialize FrozenDicts using safetensors as the tensor storage format instead of pickle, as well as a common, simple way to serialize and deserialize any JAX model params (Flax, Haiku, or Objax) using the safetensors format.


safejax's Issues

Move `load.py` and `save.py` to `core/` and define `partial` functions for `flax`, `haiku`, and `objax`?

Maybe it's cleaner to have one Python file per framework, i.e. flax.py, objax.py, and haiku.py, and import both the save and load functions from safejax.core, so that the functions defined in each framework file are functools.partial wrappers with predefined values for the parameters that differ between frameworks.

e.g. deserialize(..., freeze_dict=True) when importing deserialize from safejax.flax
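A rough sketch of the proposed layout: one shared function that would live in safejax.core, with each framework module exporting a functools.partial that pins its defaults. `core_deserialize` and its behavior are hypothetical stand-ins; freezing is mimicked here with types.MappingProxyType (a shallow, read-only view) rather than flax's FrozenDict.

```python
from functools import partial
from types import MappingProxyType
from typing import Any, Dict, Mapping


def core_deserialize(flat: Dict[str, Any], freeze_dict: bool = False) -> Mapping[str, Any]:
    """Hypothetical shared implementation (would live in safejax.core)."""
    params = dict(flat)  # real code would decode the bytes and unflatten here
    # Stand-in for producing a flax FrozenDict: a read-only view of the dict
    return MappingProxyType(params) if freeze_dict else params


# e.g. what safejax/flax.py could export: frozen output by default
flax_deserialize = partial(core_deserialize, freeze_dict=True)

# e.g. what safejax/haiku.py could export: plain mutable dicts
haiku_deserialize = partial(core_deserialize, freeze_dict=False)
```

Callers would then import `deserialize` from the framework module and never pass `freeze_dict` themselves, while the logic stays in a single place.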

Define custom type in `safejax.typing` for model params

Something like:

ParamsLike = Union[Dict[str, jnp.DeviceArray], Dict[str, np.ndarray], FrozenDict, VarCollection]

To avoid defining all the types over and over, and also so that the type-hints are aligned between both safejax.serialize and safejax.deserialize 🤗

Incompatibility issue with flax

When importing the safejax library, I get the following warning from flax:
FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.

Add unit tests for `objax` and `haiku` model params

Currently, the unit tests of safejax only cover the basic functionality, and not thoroughly, so they could be improved. One improvement is to extend them to make sure that safejax works as expected when serializing/deserializing models from any of the supported frameworks, i.e. jax, flax, objax, and haiku.

Upgrade orbax deps?

Installing with pip3 fails with the following error:

      *** Orbax is a namespace, and not a standalone package. For model checkpointing and exporting utilities, please install `orbax-checkpoint` and `orbax-export` respectively (instead of `orbax`). ***

Add `pip-tools` to lock the dependencies

Add pip-tools to lock the dependencies. hatchling is just the build backend and doesn't lock dependencies, whereas poetry did everything, so poetry should be replaced with hatchling + pip-tools, not with hatchling alone.

Explore integration with `objax`

objax uses custom naming for layers, appending extra information in parentheses to ease debugging. In any case, they have already worked on something to ease renaming the variables in a VarCollection (see google/objax#85), in case we want to align the format of the keys in the stored dictionaries of tensors.

e.g. (EfficientNet).stem(ConvBnAct).conv(Conv2d).w in objax from efficientnet.vars()

This needs to be extensively tested, but if model.vars() already uses this format, there's no need to unflatten the dict decoded from the bytes using safetensors: the default structure is already a flat dictionary whose keys are the layer names (with or without the extra information) and whose values are jnp.DeviceArray tensors, so it needs to be neither flattened nor unflattened.
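For illustration, the extra debugging information could be stripped from such keys with a regular expression, producing plain dotted names. This is a hypothetical sketch of the idea, not the renaming utility discussed in google/objax#85, and it assumes the parenthesized parts never nest.

```python
import re
from typing import Dict, TypeVar

V = TypeVar("V")

# Matches the "(ClassName)" annotations objax puts in variable names
_ANNOTATION = re.compile(r"\([^()]*\)")


def strip_objax_annotations(name: str) -> str:
    """'(EfficientNet).stem(ConvBnAct).conv(Conv2d).w' -> 'stem.conv.w'"""
    return _ANNOTATION.sub("", name).lstrip(".")


def rename_keys(variables: Dict[str, V]) -> Dict[str, V]:
    """Rename every key, refusing to silently merge two variables."""
    renamed = {strip_objax_annotations(k): v for k, v in variables.items()}
    if len(renamed) != len(variables):
        raise ValueError("stripping annotations produced duplicate keys")
    return renamed
```

The duplicate check matters because two layers that differ only in their class annotation would otherwise collapse into one key.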

Maybe serialize and deserialize could behave differently depending on the framework in use, perhaps via function parameters and partial function definitions; not sure yet.

In case you're curious, @rwightman also reported some differences in layer naming conventions compared to both TF/Keras and PyTorch at google/objax#104, although this doesn't affect this issue, since the naming convention is not relevant for serialization.
