
Optax

Introduction

Optax is a gradient processing and optimization library for JAX.

Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways.

Our goals are to

  • Provide simple, well-tested, efficient implementations of core components.
  • Improve research productivity by making it easy to combine low-level ingredients into custom optimisers (or other gradient processing components).
  • Accelerate adoption of new ideas by making it easy for anyone to contribute.

We favour focusing on small, composable building blocks that can be effectively combined into custom solutions. Others may build upon these basic components to create more complicated abstractions. Whenever reasonable, implementations prioritise readability and structuring the code to match the standard equations, over code reuse.

An initial prototype of this library was made available in JAX's experimental folder as jax.experimental.optix. Given the wide adoption of optix across DeepMind, and after a few iterations on the API, optix was eventually moved out of experimental as a standalone open-source library and renamed optax.

Documentation on Optax can be found at optax.readthedocs.io.

Installation

You can install the latest released version of Optax from PyPI via:

pip install optax

or you can install the latest development version from GitHub:

pip install git+https://github.com/google-deepmind/optax.git

Quickstart

Optax contains implementations of many popular optimizers and loss functions. For example, the following code snippet uses the Adam optimizer from optax.adam and the mean squared error from optax.l2_loss. We initialize the optimizer state using the init function and params of the model.

optimizer = optax.adam(learning_rate)
params = {'w': jnp.ones((num_weights,))}
# Obtain the `opt_state` that contains statistics for the optimizer.
opt_state = optimizer.init(params)

To write the update loop we need a loss function that can be differentiated by JAX (with jax.grad in this example) to obtain the gradients.

compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
grads = jax.grad(compute_loss)(params, xs, ys)

The gradients are then converted via optimizer.update to obtain the updates that should be applied to the current parameters to obtain the new ones. optax.apply_updates is a convenience utility to do this.

updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
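
Putting the pieces together, a minimal end-to-end loop could look like the sketch below. The toy data (xs, ys), the learning_rate value, and the number of steps are illustrative assumptions, not part of the snippets above, and the loss is averaged so that jax.grad receives a scalar.

import jax
import jax.numpy as jnp
import optax

learning_rate = 1e-2
num_weights = 10
xs = jnp.ones((16, num_weights))   # toy batch of inputs
ys = jnp.full((16,), 2.0)          # toy targets

params = {'w': jnp.zeros((num_weights,))}
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

def compute_loss(params, x, y):
  # Mean of the per-example l2 losses, so jax.grad gets a scalar.
  return optax.l2_loss(x @ params['w'], y).mean()

@jax.jit
def step(params, opt_state, x, y):
  grads = jax.grad(compute_loss)(params, x, y)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

for _ in range(100):
  params, opt_state = step(params, opt_state, xs, ys)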

You can continue the quick start in the Optax 🚀 Getting started notebook.

Development

We welcome new contributors.

Source code

You can check the latest sources with the following command.

git clone https://github.com/google-deepmind/optax.git

Testing

To run the tests, please execute the following script.

sh ./test.sh

Documentation

To build the documentation, first ensure that all the dependencies are installed.

pip install -e ".[docs]"

Then, execute the following.

cd docs/
make html

Benchmarks

If you feel lost in the crowd of available optimizers for deep learning, there are some extensive benchmarks:

Benchmarking Neural Network Training Algorithms, Dahl G. et al., 2023.

Descending through a Crowded Valley - Benchmarking Deep Learning Optimizers, Schmidt R. et al., 2021.

If you are interested in developing your own benchmark for some tasks, consider the following framework:

Benchopt: Reproducible, efficient and collaborative optimization benchmarks, Moreau T. et al., 2022.

Finally, if you are searching for recommendations on tuning optimizers, consider taking a look at:

Deep Learning Tuning Playbook, Godbole V. et al., 2023.

Citing Optax

This repository is part of the DeepMind JAX Ecosystem. To cite Optax, please use the following citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/google-deepmind},
  year = {2020},
}


Issues

Prevent unnecessary recompilation due to closures in optimizers

Now that JAX supports dataclasses as PyTrees, would it be possible to switch to using them instead of namedtuple? The benefits are explained here.

The biggest benefit would be preventing unnecessary recompilation. The current Optax code uses closures, which will cause JAX to unnecessarily recompile a jitted function that accepts a GradientTransformation. (The closures are different objects that hash differently, which means that changing the parameters of the GradientTransformation forces the jitted function to recompile.)

A dataclass version of Optax would look something like this.

I am happy to submit a pull request if this change is okay.
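
For illustration, a hedged sketch of what a dataclass-based transformation could look like, using chex.dataclass so that the hyperparameter is a pytree leaf rather than a closed-over Python value; this is not an accepted Optax design, just a sketch of the idea.

import chex
import jax

@chex.dataclass
class Scale:
  step_size: float

  def init(self, params):
    del params
    return ()  # this transformation is stateless

  def update(self, updates, state, params=None):
    del params
    return jax.tree_util.tree_map(lambda g: self.step_size * g, updates), state

# e.g. tx = Scale(step_size=-1e-3); because step_size is a pytree leaf,
# passing tx into a jitted function does not force recompilation when it changes.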

MultiSteps does not work with update functions that require parameters

The MultiSteps wrapper https://github.com/deepmind/optax/blob/30947fbc5743adc3e997c6242fa3775834862a74/optax/_src/wrappers.py#L179 for gradient accumulation does not feed parameters to the wrapped optimizer. These are needed for e.g. weight decay https://github.com/deepmind/optax/blob/30947fbc5743adc3e997c6242fa3775834862a74/optax/_src/transform.py#L547 which I think is used in almost every optimizer chain.

As far as I can tell every transformation accepts params, so is there any reason not to pipe them through?

Documentation

Hey, I am a user of optix. It seems optax is going to be the future of the library, but I don't see any documentation. Is it too early to use this library?

Please consider supporting optimization of metaparameters

This is a very exciting project! I was just considering using flax.optim when I found optax, and I love the elegant combine.chain design of the various optimizer aliases. Very cool!

I'd like to consider learning as an iterated function of the parameters, which itself depends on meta-parameters (e.g. learning rate). Then, I can use the fixed point theorem to calculate the gradient of the loss on a batch with respect to the metaparameters.

Unfortunately, optax's GradientTransformations are implemented using functions that close over values, which means that these values cannot be JAX tracers. From my understanding, you cannot take the derivative with respect to the step_size if the step size is a closed-over value.

I know this might be a serious change, but would it be possible, instead of having:

def scale(step_size: float) -> GradientTransformation:
  ...
  return GradientTransformation(init_fn, update_fn)

To implement the abstract methods init_fn and update_fn in an inherited class:

class Scale(GradientTransformation):
  def __init__(self, step_size):
    ...

This design would allow:

  • taking derivatives with respect to various meta-parameters (like step_size),
  • inspecting objects (scale.step_size is available in the object oriented approach) for debugging,
  • comparing objects and preventing recompilation of jitted functions. If, for some reason, you call scale(1e-3) twice, you get a different object each time, and these objects will not compare equal. If these objects are passed to a jitted function, the function will be recompiled even though the objects would normally be equal.

JITted Adam results in NaN when setting decay to integer 0

This is due to a bug with integer exponentiation, leading to a divide by zero during bias correction on certain iteration multiples (e.g., 64 for b1=0). This issue is the closest I could get to the root cause. In the interim while the underlying issue is fixed, it could be guarded against in optax by casting the decay to a float in the bias correction helper.

Implement Differentially Private Stochastic Gradient Descent

Differentially Private SGD (https://cseweb.ucsd.edu/~kamalika/pubs/scs13.pdf) is an important algorithm in private machine learning. Essentially, it is SGD except you clip and add Gaussian noise to per-example gradients before averaging across the batch. I think this would be a useful addition to Optax.

The implementation could be based on the example in the JAX repo: https://github.com/google/jax/blob/master/examples/differentially_private_sgd.py

The usage would be slightly different from other transforms, since it requires per-example gradients as inputs. It can still be composed with other transforms as long as it is the first one in the chain. Alternatively, we can expose a stand-alone utility function that does the clipping/noise/aggregation that the user could then pass to a GradientTransform. I think the former option (making it a transform) is more convenient since this algorithm would have some state (the RNG key).

I'd be happy to work on this if it seems like a good addition.
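
For concreteness, a rough sketch of the stand-alone clip/noise/aggregate utility option described above; the function name, the leading batch axis on the per-example gradients, and the l2_norm_clip/noise_multiplier arguments are assumptions for illustration only.

import jax
import jax.numpy as jnp

def dp_aggregate(per_example_grads, key, l2_norm_clip, noise_multiplier):
  leaves, treedef = jax.tree_util.tree_flatten(per_example_grads)
  batch_size = leaves[0].shape[0]
  # Per-example global L2 norm across all leaves.
  norms = jnp.sqrt(sum(jnp.sum(jnp.square(g.reshape(batch_size, -1)), axis=1)
                       for g in leaves))
  clip = jnp.minimum(1.0, l2_norm_clip / norms)
  keys = jax.random.split(key, len(leaves))
  aggregated = []
  for g, k in zip(leaves, keys):
    clipped = g * clip.reshape((batch_size,) + (1,) * (g.ndim - 1))
    noise = noise_multiplier * l2_norm_clip * jax.random.normal(k, g.shape[1:])
    # Noisy sum of clipped per-example gradients, averaged over the batch.
    aggregated.append((jnp.sum(clipped, axis=0) + noise) / batch_size)
  return jax.tree_util.tree_unflatten(treedef, aggregated)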

Fix hyperlink rendering in Optax API docs on ReadTheDocs

Currently, the Markdown-flavored links (inside Python files) in Optax API docs (from Python doc strings) appear to not render well on the ReadTheDocs site.

For example, this (source: https://github.com/deepmind/optax/tree/master/optax/_src/alias.py#L322-L374):

def rmsprop(
      ...
) -> base.GradientTransformation:
  """A flexible RMSProp optimiser.
  ...
  References:
    [Tieleman and Hinton, 2012](
        www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    [Graves, 2013](https://arxiv.org/abs/1308.0850)
    ...

translates into poorly rendered output on the ReadTheDocs site (source: https://optax.readthedocs.io/en/latest/api.html#rmsprop; screenshot omitted).

Potential solutions:

  • Try reStructuredText-flavored formatting for links (since Sphinx likes rST) and see if this fixes the rendering:
`Text <URL>`_

(Source: https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html#hyperlinks)

  • Alternatively, avoid introducing rST in Python files and, instead, remove Markdown-flavored links ([text](URL)) to keep things simple, like in Haiku docs, since most links are from arXiv and they are quite short:
- [Graves, 2013](https://arxiv.org/abs/1308.0850)
+ Graves, 2013 https://arxiv.org/abs/1308.0850

which is similar to https://dm-haiku.readthedocs.io/en/latest/api.html#id1:

class Linear(hk.Module):
  """Linear module."""

  def __init__(
      ...
  ):
    """Constructs the Linear module.
    Args:
      ...
      w_init: Optional initializer for weights. By default, uses random values
        from truncated normal, with stddev ``1 / sqrt(fan_in)``. See
        https://arxiv.org/abs/1502.03167v3.


LMKWYT @mtthss

I can open a PR 👍

Reduce_on_plateau

Hi,
in the learning rate schedule list it may be useful to add a "reduce on plateau" scheduler, which monitors the loss over a certain number of epochs and, if it is not decreasing, divides the learning rate by a certain factor (possibly clipped at a minimum value).
In PyTorch I have found this to be useful for some of my use cases when optimizing CNNs.
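
For illustration, a rough sketch of the idea as a plain Python helper called once per epoch from the user's own loop; the names and defaults are illustrative, not an Optax API.

def reduce_on_plateau(lr, loss, best_loss, wait,
                      patience=10, factor=0.5, min_lr=1e-6):
  # Shrink the learning rate when the loss has not improved for `patience`
  # consecutive calls; `best_loss` and `wait` are carried by the caller.
  if loss < best_loss:
    return lr, loss, 0
  wait += 1
  if wait >= patience:
    return max(lr * factor, min_lr), best_loss, 0
  return lr, best_loss, wait

# usage (per epoch): lr, best_loss, wait = reduce_on_plateau(lr, val_loss, best_loss, wait)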

Weight normalization

Is there an equivalent to flax.optim.WeightNorm? As flax.optim is effectively deprecated in favor of optax, I would like to see it implemented in optax.

Implement log(cosh()) loss function

log-cosh is a doubly differentiable alternative to the Huber loss. A naive implementation is prone to overflow (since cosh has an e^x term), so I think it would be a useful addition to the library. Plus, it's implemented in other libraries, such as TensorFlow.

If this sounds like a relevant addition, I'd be happy to contribute it!
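
For concreteness, one possible numerically stable formulation; this is a hedged sketch, not necessarily how the loss would be added to Optax.

import jax.numpy as jnp

def log_cosh_loss(predictions, targets):
  # Stable identity: log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2),
  # which avoids the overflow of the e^x term in a naive cosh.
  x = jnp.abs(predictions - targets)
  return x + jnp.log1p(jnp.exp(-2.0 * x)) - jnp.log(2.0)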

More examples

This issue provides a home for work on adding informative examples of using optax (for instance examples reproducing results from interesting optimisation papers).

Reach out on this issue if you are interested and/or have suggestions.

Import optax on Colab gives: cannot import name 'flags' from 'jax.config'

I understand that this seems to be a Colab-notebook-specific error. If it is not appropriate to raise the issue here, I would be happy to raise it elsewhere. :)

ImportError                               Traceback (most recent call last)

<ipython-input-9-72cd76e3a907> in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8 

6 frames

/usr/local/lib/python3.7/dist-packages/optax/__init__.py in <module>()
     16 """Optax: composable gradient processing and optimization, in JAX."""
     17 
---> 18 from optax._src.alias import adabelief
     19 from optax._src.alias import adagrad
     20 from optax._src.alias import adam

/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py in <module>()
     20 import jax.numpy as jnp
     21 
---> 22 from optax._src import combine
     23 from optax._src import privacy
     24 from optax._src import schedule

/usr/local/lib/python3.7/dist-packages/optax/_src/combine.py in <module>()
     16 """Flexibly compose gradient transformations."""
     17 
---> 18 from optax._src import transform
     19 GradientTransformation = transform.GradientTransformation
     20 

/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py in <module>()
     18 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
     19 
---> 20 import chex
     21 import jax
     22 import jax.numpy as jnp

/usr/local/lib/python3.7/dist-packages/chex/__init__.py in <module>()
     15 """Chex: Testing made fun, in JAX!"""
     16 
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_gt
     19 from chex._src.asserts import assert_devices_available

/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py in <module>()
     29 import jax
     30 import jax.numpy as jnp
---> 31 import jax.test_util as jax_test
     32 import numpy as np
     33 import tree as dm_tree

/usr/local/lib/python3.7/dist-packages/jax/test_util.py in <module>()
     33 from . import dtypes as _dtypes
     34 from . import lax
---> 35 from .config import flags, bool_env, config
     36 from ._src.util import partial, prod
     37 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce

ImportError: cannot import name 'flags' from 'jax.config' (/usr/local/lib/python3.7/dist-packages/jax/config.py)

Stateless Transformation

Right now, there's some boilerplate for defining simple gradient transformations. If a user wants to implement custom weight decay, clipping, constraints, etc without any state, they still have to define a nested function with an init that does nothing, and handle the empty state.

I think it'd be convenient to provide a stateless transformation that accepts a function to apply to the updates and params. We can also do the jax.tree_multimap for the user by default.

weight_decay = optax.stateless(lambda g, p: g + 0.1 * p)

If the user wants to define a function that does the tree_multimap themselves:

def my_function(updates, params):
    return jax.tree_multimap(lambda g, p: ..., updates, params)

optim = optax.stateless(my_function, on_leaves=False)

In my view, this is a very clean way to implement simple stateless transformations. I think the on_leaves argument could use a better name, though. I'd be happy to implement this if it sounds reasonable.
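
For illustration, a hedged sketch of what such a helper could look like, built from the existing GradientTransformation and EmptyState pieces; the stateless name and the on_leaves flag mirror the proposal above and are illustrative only.

import jax
import optax

def stateless(f, on_leaves=True):
  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    if on_leaves:
      # Apply f leaf-wise over (update, param) pairs.
      new_updates = jax.tree_util.tree_map(f, updates, params)
    else:
      new_updates = f(updates, params)
    return new_updates, state

  return optax.GradientTransformation(init_fn, update_fn)

# e.g. weight_decay = stateless(lambda g, p: g + 0.1 * p)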

Prevent creating unnecessary momentum variables

Currently, optax.sgd and optax.noisy_sgd unconditionally create momentum variables for the parameters, since both rely on trace. For optax.noisy_sgd, this is unnecessary since decay is always 0. For optax.sgd, this is unexpected since momentum=0 by default (and can be wasteful for large models).

optax.noisy_sgd should only require _scale_by_learning_rate (with a negation). optax.sgd could conditionally add trace if momentum > 0.
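
For illustration, a sketch of the conditional construction proposed above (not the current optax.sgd implementation): only add trace when momentum is actually used.

import optax

def sgd(learning_rate, momentum=0.0, nesterov=False):
  if momentum > 0.0:
    return optax.chain(
        optax.trace(decay=momentum, nesterov=nesterov),
        optax.scale(-learning_rate))
  # No momentum: skip trace entirely, so no momentum variables are created.
  return optax.scale(-learning_rate)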

Below are the lines of code I'm referring to:

https://github.com/deepmind/optax/blob/ba0bc11d172054d65b4387ecae840c04e2bc7035/optax/_src/alias.py#L142-L148

https://github.com/deepmind/optax/blob/ba0bc11d172054d65b4387ecae840c04e2bc7035/optax/_src/alias.py#L105-L113

And here's where trace automatically creates its state:

https://github.com/deepmind/optax/blob/ba0bc11d172054d65b4387ecae840c04e2bc7035/optax/_src/transform.py#L212-L213

[REQ] Conda recipe

Hi,
I'm the lead developer of NetKet, an established machine learning / quantum physics package.

We have recently finished rewriting our core to be based on Jax (and flax), and recently released a beta version.
Since many physicists seem to use anaconda, we would also like to update our conda recipe.
However, since we depend on Optax, we would need Optax to have a Conda recipe.

Is that something you'd consider? I already contributed the work for Chex and have the Optax recipe ready to go.
I am willing to volunteer the recipe by myself. I just need another member of the flax-team to be listed as maintainer of the recipe.

The recipe itself will be low maintenance, as it will pick up PyPI releases automatically and publish new versions unless errors arise.

cc @hbq1

__all__ is currently not useful in __init__.py

__all__ in the __init__.py overrides which functions are imported when a user does from optax import *. It seems like every function in that file should be exposed through a wildcard import, so there is no need for the __all__. Besides being redundant, having it creates opportunities for bugs: right now, many of the functions (e.g. maybe_update, keep_params_nonnegative) are imported but not exposed in __all__. I believe it should be removed, and would be happy to create a PR if that makes sense.

DPSGD example's dataset not shuffled between epochs

Since the dataset is turned into a list, the same batch order and batches are used for each epoch:

https://github.com/deepmind/optax/blob/30947fbc5743adc3e997c6242fa3775834862a74/examples/differentially_private_sgd.py#L131

Training loop:

https://github.com/deepmind/optax/blob/30947fbc5743adc3e997c6242fa3775834862a74/examples/differentially_private_sgd.py#L163-L168

The dataset should not be turned into a list at all, but there is a significant performance drop when the tf.data.Dataset is used directly (~4s/epoch with the tf.data.Dataset, ~0.7 seconds with the list). I was not able to improve this, so would appreciate if someone with tf.data expertise could take a look. Thanks!

Option to not decay bias with additive_weight_decay

Currently, additive_weight_decay will decay all the parameters. Jia et al. 2018 show it is beneficial not to decay the bias parameters (and only decay the weights). Many training examples implement this too, such as Flax's Imagenet example.

One way to support this is to add a decay_bias: bool argument to additive_weight_decay, then within the decay update:

updates = jax.tree_multimap(
    lambda g, p: g + weight_decay * p if (p.ndim > 1 or decay_bias) else g,
    updates, params)

As far as I'm aware for common NN layers, the bias always has one dimension, so checking if p.ndim > 1 is sufficient (please correct me if I'm wrong).

I'd be happy to contribute this if it sounds reasonable.

EDIT: I realized batch norm scale parameters also have only one dimension, so this filter would wrongly include those. However, many training pipelines do not regularize batchnorm scale/bias (e.g. the reference imagenet implementation from mlperf).

Manually setting the learning_rate?

Hi there,

Is it possible to set the learning rate manually? e.g.

# Setup optimiser
opt_init, opt_update = optax.adam(learning_rate=1e-3)
opt_state = opt_init(params)

# Train
for epoch_num in range(10):
    # Compute gradients.
    grads = jax.grad(loss_function)(params, data)
    # Transform the gradients using the optimiser.
    updates, opt_state = opt_update(grads, opt_state)
    # Update parameters.
    params = optax.apply_updates(params, updates)

    # *** MY IDEA/INTENTION ***:
    if epoch_num == 5:
        opt_state.learning_rate = 1e-4  # does something like this exist?

Many thanks for any help, and for this fantastic lib! :)
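
One possible approach, assuming a recent enough Optax version that provides optax.inject_hyperparams, is to store the learning rate in the optimizer state so it can be read and overridden between steps; a minimal sketch with toy parameters:

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}  # toy parameters for illustration

optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=1e-3)
opt_state = optimizer.init(params)

# ... later, between update steps ...
opt_state.hyperparams['learning_rate'] = 1e-4  # subsequent updates use 1e-4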

More flexible type annotations for schedules.

Step variables in schedules have type int instead of Union[int, float].

This is because schedules are currently used to control learning rates from integer step counts.

Users have requested that we make it possible to use schedules in other contexts, where the input would no longer be an integer.

One option here would be to admit more general types in the typings, and then remove specific references to steps in the schedules.

Feature Proposal: multi-optimizer / multi-transform

Currently, to apply different optimizers to different sets of parameters, you need to construct multiple masks and chain them. For example, to optionally apply weight decay to some parameters and not to others:

tx = optax.chain(
    optax.masked(optax.adamw(0.01, weight_decay=0.0), 
                 mask=partial(jax.tree_map, lambda p: p.ndim == 1)),
    optax.masked(optax.adamw(0.01, weight_decay=1e-3), 
                 mask=partial(jax.tree_map, lambda p: p.ndim != 1)),
)

However, this has a few issues:

  1. When you have more than two groups, creating mask functions for them all becomes verbose
  2. The burden of verifying the masked groups are mutually exclusive falls on the user (which can be bothersome for groups > 2)
  3. This solution isn't obvious to a new user

I'd like to propose a multi_transform (or multi_transformation, as always naming is hard) that generalizes optax.masked:

def multi_transform(transforms: Sequence[GradientTransformation], 
                    partition: Union[PyTree, Callable[[base.Params], PyTree]]):
  ... 

Where partition is a pytree where the leaves contain an index (or a function that returns such a pytree given the parameters). The index corresponds to a GradientTransformation in transforms.

With this, our above example would be:

tx = optax.multi_transform(
    [optax.adamw(0.01, weight_decay=1e-3), optax.adamw(0.01, weight_decay=0.0)],
    partition=partial(jax.tree_map, lambda p: int(p.ndim == 1)))  # 0 for weight decay, 1 for no weight decay

This solves our above problems:

  1. You only need one pytree (function)
  2. The groups are mutually exclusive by design, since there's only one pytree with indices
  3. The name matches what users would expect from such a feature (due to similar naming in other libraries, e.g. Flax).

Minor details:

  • partition could just be partition_fn: Callable[[base.Params], PyTree] (I left it as partition to mirror the signature of masked).
  • partition could have a different name, perhaps indices.

cc: @mtthss @jheek @andsteing

Norm of complex numbers

When using clip_by_global_norm on gradients of complex parameters, it seems we need to change jnp.square(x) to x.conj() * x in the function global_norm in _src/linear_algebra.py.

What is the current status of complex number support in Optax? I'm using neural networks in quantum physics, and I'd be happy to help the JAX community enhance complex number support. I expect I'll encounter more problems with it, and I'll report them as they come up.
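
For illustration, a minimal sketch of the suggested change, assuming the tree-flattening structure of the existing global_norm; jnp.real keeps the result a real scalar while leaving the behaviour for real-valued leaves unchanged.

import jax
import jax.numpy as jnp

def global_norm(updates):
  # x.conj() * x equals x**2 for real leaves and |x|**2 for complex leaves.
  return jnp.sqrt(sum(
      jnp.sum(jnp.real(x.conj() * x))
      for x in jax.tree_util.tree_leaves(updates)))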

RMSProp does not match original Tensorflow impl

A lot of papers using RMSProp have hparams that were tuned with the original Tensorflow impl.

The Optax impl is missing the momentum option and initializes the rms value differently.

TF1 RMSProp (https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/training/rmsprop.py#L126)

  • has momentum option
  • initializes RMS value to ones

Keras RMSProp (https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/optimizer_v2/rmsprop.py#L35-L299)

  • has momentum option
  • initializes RMS values to zeros (as with here and PyTorch)

I've spent time replicating results of papers like EfficientNet (and related) in PyTorch and ended up using my own RMSProp impl that matches the TF1 variant (the PyTorch implementation does not match it either).

Change masked wrapper to use mask_fn instead of mask

Currently, optax.masked accepts a boolean mask that has the same structure as (or is a prefix of) the parameters pytree. This breaks the pattern of only requiring the parameters during init and not before. This issue is to discuss the possibility of changing the mask argument to mask_fn, which would be a function that takes a parameter pytree as input and returns a mask that has the same/prefix structure as the params.

One clear advantage of the mask_fn approach is that users can define an optimizer independently of the model (as is the case with every other transformation in optax).

The current use pattern is still possible by passing in mask=lambda _: mask for a premade mask if desired.

This was originally proposed by @jheek in the New Optimizer API for Flax discussion.

cc @mtthss @andsteing

Extracting learning rate from optimizer state directly

Hi! I'm trying to extract the learning rate from an optax optimizer directly for logging to Tensorboard.

I know I could get it from my learning rate schedule object instead by passing in step, but we've previously run into situations where the optimizer step # and expected step # went out of sync (our fault, not optax's), so to be safe we'd like to get it directly from the optimizer object. In TensorFlow you can do self._optimizer._get_hyper('learning_rate') to access it since it gets logged via _set_hyper. Is there an easy way to do something similar in optax?

What does optax.mask do?

I don't understand what optax.mask does.

I would expect that the masked optimizer

optax.mask(optax.sgd(0.1), {"Dense": True, "bias": False})

would only apply the optimisation to the sub-leaves of Dense and not optimise the sub-leaves of bias,
which means that the masked gradients should match the SGD ones for Dense and be zero for bias.

However, it seems to me that the masked updates are correct for the sub-leaves of Dense (where the mask is True), but they are the identity where the mask is False.

Is this intended behaviour? It seems rather strange to me.
I was trying to update only a subset of the weights of my model, but this was not working.

MWE:

import jax.numpy as jnp
import jax
import optax

pars = {"Dense": {"kernel": jnp.zeros((2,3)), "bias": jnp.zeros((3))}, "bias":jnp.zeros(2)}
grad = jax.tree_map(jnp.ones_like, pars)

op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})

op_state = op.init(pars)

masked_updates, new_op_state = op.update(grad, op_state,  pars)
>>> masked_updates
{'Dense': {'bias': DeviceArray([-0.1, -0.1, -0.1], dtype=float32), 'kernel': DeviceArray([[-0.1, -0.1, -0.1],
             [-0.1, -0.1, -0.1]], dtype=float32)}, 'bias': DeviceArray([1., 1.], dtype=float32)}

How to combine parameters from different haiku modules?

In PyTorch, it is possible to have a single optimizer for different nn.Module sets of parameters, and there are various ways to combine different modules' parameters.

For example, from https://github.com/altosaar/variational-autoencoder/blob/dfb452b5421e9e5b97315c6420b8766ac86f3f4f/train_variational_autoencoder_pytorch.py#L216:

optimizer = torch.optim.RMSprop( list(model.parameters()) + list(variational.parameters()), lr=cfg.learning_rate, centered=True, )

What is the equivalent in optax? Is it chaining optimizers, or is such functionality not supported at this time, requiring different instances of optax optimizers, one per haiku.Module?

Thanks so much!
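
One possible approach, assuming Haiku parameter dictionaries, is to merge the two modules' parameters into a single pytree and keep one optimizer and one optimizer state for both; a minimal sketch with toy stand-ins for the user's parameters:

import haiku as hk
import jax.numpy as jnp
import optax

# Toy stand-ins for the two modules' parameters.
model_params = {'model/linear': {'w': jnp.zeros((3, 2)), 'b': jnp.zeros((2,))}}
variational_params = {'variational/linear': {'w': jnp.zeros((2, 1))}}

params = hk.data_structures.merge(model_params, variational_params)
optimizer = optax.rmsprop(learning_rate=1e-3, centered=True)
opt_state = optimizer.init(params)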

optax.mask does not play well with flax FrozenParams

The mask must be returned/given as a FrozenDict, which is annoying.
I'm not sure this is really an optax bug... but could something be done to alleviate this?

This also shows up in multi_transform, where the fix is less obvious because it internally builds a dict, so the only way to make it work is to unfreeze the params before giving them to the optimiser, which is... inconvenient.

>>> import jax.numpy as jnp
>>> import jax
>>> import optax
>>> from flax.core import freeze, unfreeze
>>> 
>>> pars = freeze({"Dense": {"kernel": jnp.zeros((2,3)), "bias": jnp.zeros((3))}, "bias":jnp.zeros(2)})
>>> op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})
>>> op.init(pars)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/filippovicentini/Documents/pythonenvs/netket_env/lib64/python3.8/site-packages/optax/_src/wrappers.py", line 311, in init_fn
    flat_params = treedef.flatten_up_to(params)
ValueError: Expected dict, got FrozenDict({
    Dense: {
        kernel: DeviceArray([[0., 0., 0.],
                     [0., 0., 0.]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
    bias: DeviceArray([0., 0.], dtype=float32),
}).

Multiple optimizers using optax

Sorry if this already exists in optax as a feature, but how would you go about making a multi-optimizer (similar to Flax optim) that could use different learning rates for different parts of a network?

Specifically, I have a full model in Haiku with one learning rate for most of the parameters, but different learning rates for specific subsets. I can partition the params appropriately and create a separate optimizer for each subset, but ideally I'd like to maintain the simplicity of a single set of params and a single optimizer_state. Is there a common approach to this?

Unclear how to use the optax.dpsgd optimizer with pmap

Hi,

The optax.dpsgd optimizer is special in that it takes per-example gradients as input, and takes care of aggregating them.

There is currently no documentation of how this implementation is supposed to be used with multiple devices (using pmap), and there might be a few options to do it.

@n2cholas is this something you have thought about?

Add documentation of wrappers to README

The README currently does not mention wrappers and functionality like MultiSteps. Add documentation of this to the README to make it easier to discover these features.

adam does not learn?

Hi, I used optax to implement the following convnet to classify the MNIST dataset. I wonder why it is not learning?

import itertools
import time

import haiku as hk
import jax
import jax.numpy as jnp
import numpy.random as npr
import optax
from examples import datasets
from jax import grad, jit, random
from jax.experimental import optimizers, stax
from jax.experimental.stax import (
    Dense,
    Flatten,
    GeneralConv,
    LogSoftmax,
    Relu,
    elementwise,
)


def net_fn(x) -> jnp.ndarray:
    """Standard LeNet-300-100 MLP network."""
    mlp = hk.Sequential(
        [
            hk.Conv2D(output_channels=16, kernel_shape=[5, 5], padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=[2, 2], strides=[2, 2], padding="VALID"),
            hk.Conv2D(output_channels=32, kernel_shape=[5, 5], padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=[2, 2], strides=[2, 2], padding="VALID"),
            hk.Flatten(),
            hk.Linear(10),
        ]
    )
    return mlp(x)


net = hk.without_apply_rng(hk.transform(net_fn))


def loss(params, batch):
    inputs, targets = batch
    preds = net.apply(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(net.apply(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)


if __name__ == "__main__":

    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()

    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield train_images[batch_idx].reshape(-1, 1, 28, 28), train_labels[
                    batch_idx
                ]

    batches = data_stream()

    optimizer = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), optax.scale(-step_size)
    )

    @jit
    def update(params, optimizer_state, batch):
        grads = grad(loss)(params, batch)
        optim_update, optimizer_state = optimizer.update(grads, optimizer_state, params)
        params = optax.apply_updates(params, optim_update)
        return params, optimizer_state

    params = net.init(jax.random.PRNGKey(42), next(batches)[0])

    # params = init_params
    optimizer_state = optimizer.init(params)

    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            params, optimizer_state = update(params, optimizer_state, next(batches))

        epoch_time = time.time() - start_time

        train_acc = accuracy(
            params, (train_images.reshape(-1, 1, 28, 28), train_labels)
        )
        test_acc = accuracy(params, (test_images.reshape(-1, 1, 28, 28), test_labels))
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))

I got the following results:

Starting training...
Epoch 0 in 1.73 sec
Training set accuracy 0.13341666758060455
Test set accuracy 0.14079999923706055
Epoch 1 in 0.34 sec
Training set accuracy 0.09730000048875809
Test set accuracy 0.10090000182390213
Epoch 2 in 0.32 sec
Training set accuracy 0.09966666996479034
Test set accuracy 0.10300000756978989
Epoch 3 in 0.33 sec
Training set accuracy 0.12701666355133057
Test set accuracy 0.12640000879764557
Epoch 4 in 0.34 sec
Training set accuracy 0.13825000822544098
Test set accuracy 0.13690000772476196
Epoch 5 in 0.37 sec
Training set accuracy 0.11961666494607925
Test set accuracy 0.1193000078201294
Epoch 6 in 0.34 sec
Training set accuracy 0.16324999928474426
Test set accuracy 0.1599000096321106
Epoch 7 in 0.33 sec
Training set accuracy 0.1628333330154419
Test set accuracy 0.16090001165866852
Epoch 8 in 0.38 sec
Training set accuracy 0.16438333690166473
Test set accuracy 0.1623000055551529
Epoch 9 in 0.33 sec
Training set accuracy 0.10546667128801346
Test set accuracy 0.10450000315904617
Epoch 10 in 0.35 sec
Training set accuracy 0.1001666709780693
Test set accuracy 0.09920000284910202
Epoch 11 in 0.34 sec
Training set accuracy 0.09881667047739029
Test set accuracy 0.09830000251531601
Epoch 12 in 0.32 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09800000488758087
Epoch 13 in 0.33 sec
Training set accuracy 0.09881667047739029
Test set accuracy 0.09830000251531601
Epoch 14 in 0.34 sec
Training set accuracy 0.10068333148956299
Test set accuracy 0.10000000149011612
Epoch 15 in 0.35 sec

It would be nice if someone knew the reason!

[RFC] Proposal for complex-valued optimization in Optax

As mentioned in #144, we would like to help enhance the support for complex numbers in the JAX community. The detailed proposal document is here: https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29

In brief, we need to properly implement the norm of complex variables in the optimizers. We need to decide whether to implement the complex norm, the split real norm, or both of them.

Although there is also a comment section below the gist, I would like to keep the discussion in this issue thread. Feel free to leave your comments!

Support specifying end_value for exponential_decay

It would be convenient to support specifying an end_value for exponential decay. The exponential_decay schedule would then require init_value, transition_steps, and one of end_value or decay_rate. Here is an example of what I'm proposing.

I would be happy to contribute this, if you think it is suitable.
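
For illustration, a hedged sketch of the proposed convenience expressed in terms of the existing optax.exponential_decay; the total_steps argument is an assumed addition used to convert a target end_value into a decay_rate.

import optax

def exponential_decay_to_end_value(init_value, end_value, total_steps,
                                   transition_steps=1):
  # Choose decay_rate so that the schedule reaches end_value at total_steps:
  # init_value * decay_rate ** (total_steps / transition_steps) == end_value.
  decay_rate = (end_value / init_value) ** (transition_steps / total_steps)
  return optax.exponential_decay(
      init_value=init_value,
      transition_steps=transition_steps,
      decay_rate=decay_rate)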
