
praxis's Introduction

Praxis

What is Praxis? Praxis is the layer library for Pax. While Praxis is optimized for ML at scale, it also aims to be usable by other JAX-based ML projects.

Some examples of layers to be folded into Praxis are in the praxis/layers/ directory.
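For a quick feel of the API, here is a minimal usage sketch. It assumes the linears.Linear layer and its input_dims/output_dims fields from praxis/layers/; since Praxis layers are Flax modules, init and apply work as usual:

from jax import random
from praxis import base_layer
from praxis import pax_fiddle
from praxis.layers import linears

# Configure a projection layer via a pax_fiddle template, then build it.
linear_p = pax_fiddle.Config(
    linears.Linear, name='proj', input_dims=16, output_dims=32)
linear = base_layer.instantiate(linear_p)

inputs = random.uniform(random.PRNGKey(0), (4, 16))
variables = linear.init(random.PRNGKey(1), inputs)
out = linear.apply(variables, inputs)  # shape (4, 32)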

Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

praxis's People

Contributors

a9isha, aaroey, ashors1, bignamehyp, cdh4696, changlan, dhr, dryman, edloper, hawkinsp, jianlijianli, jihwanlee-alphago, jysohn23, kaixih, laurentes, m-orsini, phoenix-meadowlark, ppwwyyxx, protoget, rohan-anil, royaurko, rybakov, saeta, shivaniag, tink-expo, ukoxyz, vlad17, yashk2810, zhangqiaorjc, zhangyujing


praxis's Issues

Support custom FP8 dtype in Pipelined Transformer

We have submitted two PRs (this PR and this PR) to introduce a new custom data type for FP8 params, also known as OWG params. The primary purpose of this custom data type is to enable custom gradient accumulation using the max operation.

Once the aforementioned PRs are merged, we will still need one additional change, most likely to LayerwiseShardablePipelined, to perform the type conversion outside the scan_fn. This is necessary because the custom data type needs to be recognized before the params are broadcast into the iterations within the scan_fn, so that autograd correctly applies the custom gradient accumulation.
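To make the intended ordering concrete, here is a minimal, self-contained sketch (not Praxis code: the float8 dtype, the tree structure, and pipelined_fprop are illustrative stand-ins). The masked cast happens before jax.lax.scan, so every iteration of the loop body already sees the custom dtype:

import jax
import jax.numpy as jnp

OWG_DTYPE = jnp.float8_e4m3fn  # stand-in for the custom FP8/OWG dtype

def cast_owg_params(params, owg_mask):
  # Convert only the leaves flagged as OWG; leave the rest untouched.
  return jax.tree_util.tree_map(
      lambda p, is_owg: p.astype(OWG_DTYPE) if is_owg else p,
      params, owg_mask)

def pipelined_fprop(params, owg_mask, xs):
  # The conversion happens here, OUTSIDE the scan body, so the params are
  # already in the custom dtype when they are broadcast into each iteration.
  params = cast_owg_params(params, owg_mask)

  def scan_fn(carry, x):
    scale = params['scale'].astype(jnp.float32)  # closed over by the body
    return carry + scale * x, None

  out, _ = jax.lax.scan(scan_fn, jnp.zeros_like(xs[0]), xs)
  return out

params = {'scale': jnp.ones((), jnp.float32)}
owg_mask = {'scale': True}
print(pipelined_fprop(params, owg_mask, jnp.arange(4.0)))  # 6.0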

I have prepared a self-contained Python script for this potential change, which you can find here.

Essentially, you can disregard the lines before line 199, treating them as if they had already been merged. Line 243 shows the proposed dtype conversion to be added to LayerwiseShardablePipelined, where we convert all OWG params into the custom data type.

However, there is an issue regarding how to obtain the mask of the OWG params. As far as I understand, OWG params physically reside in the PARAMS category, and we have weight hparams to determine whether they are OWG or not. However, those weight hparams seem inaccessible inside LayerwiseShardablePipelined. In the provided code, I compute the owg_mask outside the layer and pass it as an input to model.apply (line 263). Nevertheless, I feel this is not an ideal design, since it modifies the model call signature and is specific to the FP8 scenario.

Ideally, if we could compute the owg_mask inside the layer (similar to line 226) by accessing the weight hparams, that would be preferable. I've observed a similar example with bf16_accum_in_fp32 here, although it doesn't require any weight hparams.

To sum up, what is the best practice for obtaining the owg_mask inside LayerwiseShardablePipelined, where the weight hparams are not available?

(Note: to run the gist code, you need a recent JAX nightly build such as 0.4.24.devxxxxx.)

cc. @zhangqiaorjc

Praxis layers don't support user-specified collection names

We noticed that Praxis does not allocate variables into a user-specified collection; instead, the variables stay in the default params collection. For example, in the following script we have a custom layer Foo, where we would like the variable input_scale to be created in the fp8_params collection. However, the output shows XXX vars {'params': {'input_scale': None, 'w': None}}, meaning that input_scale is placed in the params collection. What we expect instead is XXX vars {'fp8_params': {'input_scale': None}, 'params': {'w': None}}.

The motivation and use case is that we need to maintain a set of variables for FP8 support, and updating these variables requires a special process: (1) we use the custom_vjp mechanism to define how the gradients of these variables are computed, where the gradients are essentially the new variable values, and (2) when applying gradients, we replace the variables with these gradients. To facilitate this, we would like to declare a new collection to hold such variables.
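As a concrete illustration of step (1), here is a minimal, self-contained jax.custom_vjp sketch of the "gradient carries the new value" trick (the op name and the amax-based update are illustrative, not our exact implementation):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def scale_op(x, scale):
  return x * scale

def _scale_op_fwd(x, scale):
  return x * scale, (x, scale)

def _scale_op_bwd(res, g):
  x, scale = res
  dx = g * scale
  # The "gradient" of `scale` is not a true gradient: it is the variable's
  # next value (here, the amax of the inputs), so the apply-grads step can
  # simply overwrite the variable with it.
  new_scale = jnp.max(jnp.abs(x)).reshape(1)
  return dx, new_scale

scale_op.defvjp(_scale_op_fwd, _scale_op_bwd)

x = jnp.array([1.0, -3.0, 2.0])
scale = jnp.ones((1,))
_, grads = jax.value_and_grad(
    lambda x, s: scale_op(x, s).sum(), argnums=(0, 1))(x, scale)
print(grads[1])  # [3.] -- the updated scale, not a real gradient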

I have created a branch that fixes the above issue here, but it is very specific to our use case. So I am wondering whether you have any ideas or suggestions for how to improve this.

cc. @pjannaty @nluehr @reedwm

from typing import Optional

from jax import lax
from jax import numpy as jnp
from jax import random
from jax import tree_util as jtu

from praxis import base_layer
from praxis import pax_fiddle
from praxis import pytypes

instantiate = base_layer.instantiate
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
template_field = base_layer.template_field
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
JTensor = pytypes.JTensor

class Dot(base_layer.BaseLayer):
  """Wrapper around lax.dot used in standard Pax layers."""

  def __call__(self, lhs: JTensor, rhs: JTensor) -> JTensor:
    return lax.dot(lhs, rhs)

class Foo(base_layer.BaseLayer):
  input_dims: int = 0
  output_dims: int = 0
  weight_init: Optional[WeightInit] = None
  dot_tpl: LayerTpl = template_field(Dot)

  def setup(self) -> None:
    wp = self.weight_split_dims_mapping
    self.create_variable(
        'w',
        WeightHParams(
            shape=[self.input_dims, self.output_dims],
            init=self.weight_init,
            mesh_shape=self.mesh_shape,
            tensor_split_dims_mapping=wp.wt,
        ),
    )

    # Ask for `input_scale` to be placed in the `fp8_params` collection.
    scale_args = {
        'shape': [1],
        'init': WeightInit.Constant(1.0),
        'dtype': jnp.float32,
        'mesh_shape': self.mesh_shape,
        'tensor_split_dims_mapping': None,
        'collections': ['fp8_params'],
    }
    self.create_variable('input_scale', WeightHParams(**scale_args))

    self.create_child('dot', self.dot_tpl.clone())

  def __call__(self, inputs: JTensor) -> JTensor:
    """Apply projection to inputs.

    Args:
      inputs: The inputs JTensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
    ap = self.activation_split_dims_mapping

    original_shape = inputs.shape
    assert len(original_shape) >= 2

    inputs = jnp.asarray(inputs, self.dtype)
    kernel = jnp.asarray(self.theta.w, self.dtype)

    # Reshape the inputs to 2D matrix.
    inp_mat = jnp.reshape(inputs,
                          (-1, self.input_dims))

    inp_mat = inp_mat * self.theta.input_scale

    # Actual dense layer math.
    out = self.dot(inp_mat, kernel)

    # Reshape back the outputs.
    out = jnp.reshape(out, (*original_shape[0:-1], self.output_dims))

    return out

in_size, out_size = 16, 32

prng_key = random.PRNGKey(seed=123)
prng_key, init_key, random_key = random.split(prng_key, 3)
inputs = random.uniform(random_key, (48, in_size)).astype(jnp.bfloat16)

foo_kwargs = {'input_dims': in_size, 'output_dims': out_size,
              'dtype': jnp.bfloat16}
foo: Foo = instantiate(
    pax_fiddle.Config(Foo, name='foo', **foo_kwargs)
)

variables = foo.init(init_key, inputs)
var_tree = jtu.tree_map(lambda x: None, variables)
print("XXX vars", var_tree)

Incorrect conversion from tf dtype to jax dtype

In the class DatasetInputSpecsProvider, when converting TF specs to JAX specs:

dtype=spec.dtype.as_numpy_dtype())

as_numpy_dtype is treated as a method here, when it is actually an attribute of tf.dtypes.DType (https://www.tensorflow.org/api_docs/python/tf/dtypes/DType#attributes).

The code happens to work for most dtypes, because calling the returned NumPy type produces a scalar whose .dtype attribute NumPy can still interpret. For tf.string, however, the call returns a bare instance of the underlying object type rather than the datatype itself, and that cannot be interpreted as a dtype.
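A minimal reproduction of the property-vs-method confusion (checked against recent TF releases; exact reprs may vary):

import tensorflow as tf

# As an attribute, as_numpy_dtype yields the NumPy type itself:
print(tf.float32.as_numpy_dtype)    # <class 'numpy.float32'>
# Calling it constructs an *instance* of that type instead:
print(tf.float32.as_numpy_dtype())  # 0.0, a scalar; NumPy can still read
                                    # its .dtype, so this happens to work
# For tf.string, the instance is useless as a dtype:
print(tf.string.as_numpy_dtype())   # a bare object instance, not a dtype

# The fix is simply to drop the call parentheses:
#   dtype=spec.dtype.as_numpy_dtype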

Any publicly available document?

Hi Praxis team, I had been using JAX and Flax for quite some time before finding out about Praxis. Flax was great; however, it is not well suited for scaling. I also checked T5X and Flaxformer, but they do not seem very developer friendly, since their main functionality is defining transformer layers. My colleagues and I would love to move to Praxis in our future work. I wonder if there is any publicly available documentation that we can use?
