jaxtyping's Introduction

jaxtyping

Type annotations and runtime type-checking for:

  1. shape and dtype of JAX arrays; (Now also supports PyTorch, NumPy, and TensorFlow!)
  2. PyTrees.

For example:

from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching axes
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                  ) -> Float[Array, "dim1 dim3"]:
    ...

def accepts_pytree_of_ints(x: PyTree[int]):
    ...

def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
    ...
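Here is a minimal pure-Python sketch of the "matching axes" idea. Each named dimension is bound the first time it appears and is then checked on every later use, across arguments and the return value. This is not jaxtyping's implementation. `bind_dims` is a hypothetical helper that works on plain shape tuples, and it handles only named and fixed-size axes (no `...`, `*batch`, or dtypes).

```python
def bind_dims(spec, shape, memo):
    """Check `shape` against a space-separated `spec`, updating `memo` in place.

    Named axes ("dim1") are bound the first time they are seen and must then
    match on every later use; integer axes ("3") must match exactly.
    """
    names = spec.split()
    if len(names) != len(shape):
        raise TypeError(f"expected {len(names)} axes for {spec!r}, got shape {shape}")
    for name, size in zip(names, shape):
        if name.isdigit():
            if int(name) != size:
                raise TypeError(f"axis {name!r} must have size {name}, got {size}")
        elif memo.setdefault(name, size) != size:
            raise TypeError(f"axis {name!r} already bound to {memo[name]}, got {size}")
    return memo


# One memo per call: x and y share "dim2", and the output reuses "dim1" and "dim3".
memo = {}
bind_dims("dim1 dim2", (4, 5), memo)
bind_dims("dim2 dim3", (5, 6), memo)
bind_dims("dim1 dim3", (4, 6), memo)
print(memo)  # {'dim1': 4, 'dim2': 5, 'dim3': 6}

try:
    bind_dims("dim2 dim3", (7, 6), {"dim1": 4, "dim2": 5})
except TypeError as err:
    print(err)  # axis 'dim2' already bound to 5, got 7
```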

Installation

pip install jaxtyping

Requires Python 3.9+.

JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc.

The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are typeguard (which exhaustively checks every argument) and beartype (which checks random pieces of arguments).
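The difference between the two strategies is easiest to see on a container. The sketch below contrasts checking every element of a list with checking one randomly chosen element. It only mimics the two policies described above and does not use either library; `check_all` and `check_one` are hypothetical names.

```python
import random


def check_all(items, expected_type):
    """Exhaustive policy: every element is inspected (O(n) per call)."""
    return all(isinstance(item, expected_type) for item in items)


def check_one(items, expected_type, rng=random):
    """Sampling policy: one random element is inspected (O(1) per call)."""
    if not items:
        return True
    return isinstance(rng.choice(items), expected_type)


values = [1, 2, 3, "four"]
print(check_all(values, int))  # False: the bad element is always found
# check_one only catches the bad element when it happens to be sampled,
# so repeated calls catch it eventually while each call stays cheap.
print(check_one(values, int))
```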

Documentation

Available at https://docs.kidger.site/jaxtyping.

Finally

See also: other libraries in the JAX ecosystem

Equinox: neural networks.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Optimistix: root finding, minimisation, fixed points, and least squares.

Lineax: linear solvers.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Disclaimer

This is not an official Google product.

jaxtyping's People

Contributors

afrozenator, ar0ck, asford, brentyi, ebrevdo, jeertmans, jianlijianli, knyazer, murphyk, patrick-kidger, peterroelants, zaccranko

jaxtyping's Issues

TypeCheckError: argument "<arg name>" (jaxlib.xla_extension.ArrayImpl) is not an instance of jax._src.prng.PRNGKeyArray

Hi, I'm having this issue in my JAX project. I'm using pyright as a static type checker and jaxtyping + typeguard for runtime checking. The problem is that the two don't seem to agree on the type of a random.PRNGKey. Pyright is fine if I annotate it as random.PRNGKeyArray (or random.KeyArray), but this fails at runtime due to jaxtyping. On the other hand, jaxtyping seems to accept only UInt[Array, "2"] (but pyright does not).

Right now I solved it by defining:

Key = UInt[Array, "2"] | KeyArray

But I was wondering if there is a better way.

Perform addition/multiplication on shape information

It would be amazing if the following were possible:

@jaxtyped
@beartype
def double_arr(arr: f["dim"]) -> f["dim+dim"]:
    return jnp.hstack((arr, arr))
# or 
@jaxtyped
@beartype
def double_arr(arr: f["dim"]) -> f["dim*2"]:
    return jnp.hstack((arr, arr))
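One way such expressions could work is to evaluate each dimension string against the sizes already bound for the call. For a 1-D `arr` of length `dim`, `jnp.hstack((arr, arr))` has length `2 * dim`, so both spellings should give the same size. The sketch uses a hypothetical helper (`eval_dim`), not jaxtyping's API. It assumes `dim` has already been bound from the argument, and it passes an empty `__builtins__` so the expression only sees the bound sizes. That restriction is for tidiness, not a security sandbox.

```python
def eval_dim(expr, bound):
    """Evaluate a dimension expression such as "dim+dim" or "dim*2".

    `bound` maps names to sizes already fixed by earlier arguments. A fresh
    globals dict with empty builtins limits the names the expression can see
    (this is not a security boundary); `bound` itself is copied, not mutated.
    """
    return eval(expr, {"__builtins__": {}}, dict(bound))


bound = {"dim": 5}  # e.g. from arr: f["dim"] with arr.shape == (5,)
print(eval_dim("dim+dim", bound))  # 10
print(eval_dim("dim*2", bound))    # 10
```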

Vanilla dataclasses don't work

Hi!

This condition:
https://github.com/google/jaxtyping/blob/59e8fb0d18325f990a9d59ee35e90c04b699cab8/jaxtyping/decorator.py#L71

Returns True for equinox modules, but False for vanilla dataclasses. Here's a reproduction: https://github.com/brentyi/jaxtyping-export-repro/tree/broken_dataclasses

Perhaps the decorator here should be prepended rather than appended? Otherwise, _jaxtyped_typechecker will be called before @dataclass.
https://github.com/google/jaxtyping/blob/59e8fb0d18325f990a9d59ee35e90c04b699cab8/jaxtyping/import_hook.py#L114

( if this seems like a valid/trivial fix I'm happy to make a PR, maybe this weekend )
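Decorators listed above a class are applied bottom-up. A decorator appended to the end of the decorator list therefore runs before `@dataclass` has processed the class. The standard-library sketch below shows the effect; `record` is a hypothetical stand-in for the import hook's decorator.

```python
import dataclasses

seen = []


def record(cls):
    # Records whether @dataclass had already processed the class when this ran.
    seen.append(dataclasses.is_dataclass(cls))
    return cls


@record  # listed first (prepended): applied last, after @dataclass
@dataclasses.dataclass
class Prepended:
    x: int


@dataclasses.dataclass
@record  # listed last (appended): applied first, before @dataclass
class Appended:
    x: int


print(seen)  # [True, False]
```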

Batched type

Hi, I was wondering if a type-annotation for Batched datastructures exists or if it could be implemented?

So something along the lines of PyTree, but indicating that all leaf nodes have a leading axis of size, e.g., N.

I know I can already do this easily with e.g., PyTree[Float32[Array, 'N ...']], but in this way I can't use my type-aliases (or don't know how to).

Example use-case:

from jaxtyping import Array, Float32, PyTree
import jax

MyArray = Float32[Array, '...']
Batched = ???  # Should prepend 'N ' to Float32[Array, 'N ...']

def sample_fun(x: PyTree[MyArray]) -> PyTree[MyArray]:
    return x

def batch_fun(xs: Batched[PyTree[MyArray]]) -> PyTree[Batched[MyArray]]:
    return jax.vmap(sample_fun)(xs)

Note, intuitively Batched[PyTree[...]] should be equivalent to PyTree[Batched[...]] as they should operate on the leaves.

pip installing jaxtyping prints an error in colab

%pip install jaxtyping throws an error in Colab. The installation still succeeds, but it prints this scary block of red text:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.1.5 requires typing-extensions<4.2.0,>=3.7.4.1; python_version < "3.8", but you have typing-extensions 4.4.0 which is incompatible.
spacy 3.4.2 requires typing-extensions<4.2.0,>=3.7.4; python_version < "3.8", but you have typing-extensions 4.4.0 which is incompatible.
confection 0.0.3 requires typing-extensions<4.2.0,>=3.7.4.1; python_version < "3.8", but you have typing-extensions 4.4.0 which is incompatible.

We use jaxtyping as a dependency for dynamax, so pip installing our package produces this error message as well. tagging @murphyk as well.

Possible bug when using `torch.nn.Module` and `@jaxtyped` + `@typechecker`

I could be missing something, but I think something problematic may be happening when decorating methods with @jaxtyped and @typechecker.

Below is a minimal example showing that, when using the decorators @jaxtyped and @typechecker (where typechecker is beartype), forward() is no longer a types.MethodType. This causes bugs when trying to train torch.compile(model), because it asserts that model.forward is of type types.MethodType.

import torch
import types
import torch.nn as nn
from jaxtyping import jaxtyped, Float
from beartype import beartype as typechecker

class model_without_jaxtyping(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
    
model_without_jaxtyping_instance = model_without_jaxtyping()
print(type(model_without_jaxtyping_instance.forward) == types.MethodType)
# 'True'

class model_with_jaxtyping(nn.Module):
    def __init__(self):
        super().__init__()

    @jaxtyped
    @typechecker
    def forward(self, x: Float[torch.Tensor, "..."]):
        return x
    
model_with_jaxtyping_instance = model_with_jaxtyping()
print(type(model_with_jaxtyping_instance.forward) == types.MethodType)
# 'False'
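Plain functions become bound methods because they implement the descriptor method `__get__`. Suppose a wrapper is an instance of a class that defines only `__call__`; the `self.fn(*args, **kwargs)` frame in a later traceback suggests jaxtyped's wrapper looked like this at the time. Such a wrapper lacks `__get__`, so attribute access on an instance returns the wrapper itself rather than a bound method. The sketch reproduces this with the standard library. `CallWrapper` and `BoundCallWrapper` are hypothetical names, and adding `__get__` is one possible fix, not necessarily the library's actual change.

```python
import functools
import types


class CallWrapper:
    """Wraps a function in an object with __call__ but no __get__."""

    def __init__(self, fn):
        functools.update_wrapper(self, fn)
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class BoundCallWrapper(CallWrapper):
    """Same wrapper, plus the descriptor protocol that functions use."""

    def __get__(self, instance, owner=None):
        if instance is None:  # accessed on the class itself
            return self
        return types.MethodType(self, instance)


class Model:
    @CallWrapper
    def forward_plain(self, x):
        return x

    @BoundCallWrapper
    def forward_bound(self, x):
        return x


model = Model()
print(type(model.forward_plain) == types.MethodType)  # False
print(type(model.forward_bound) == types.MethodType)  # True
print(model.forward_bound(3))  # 3
```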

mypy error: Unsupported operand type for unary - ("Array")

Using mypy to typecheck the following piece of code:

import jax.numpy as jnp
from jaxtyping import Array, Float

def softplus_inv(x: Float[Array, "x"]) -> Float[Array, "x"]:
    """Inverse to softplus"""
    return x + jnp.log(-jnp.expm1(-x))

Results in:

error: Unsupported operand type for unary - ("Array")

Am I doing something wrong? Other parts of my code seem to pass the type checks; it is only the unary minus that causes issues.

Tested on:

  • Python 3.10
  • jaxtyping 0.2.5
  • jax 0.3.17
  • mypy 0.971

Export Array from jax without installation

Thanks so much for this package; it has been really useful for me. Would you consider decoupling the jax.Array import from the main jax library? I realize this would be a lot of work, but I'd like to use jaxtyping in projects that don't otherwise use JAX, and this might increase adoption.

Forward reference name is not defined error when static type checking with mypy

Because strings in types are typically interpreted as forward references I'm getting Name "*" is not defined errors using mypy.

I know from the documentation that the fix for similar flake8 issues is to prepend a space and ignore F722.

For example:

def test_error() -> Float[Array, " x"]:
    ...

Is resulting in: error: Name "x" is not defined.

Are there any recommendations for avoiding these forward-reference errors when static type checking?

I'm using:

  • Python 3.10
  • mypy 0.971
  • jax 0.3.17
  • jaxtyping 0.2.4

Evaluate using annotations with typeguard-style instrumentation

Hey @patrick-kidger!

I got this idea through our discussion in agronholm/typeguard#353. Wanted to create an issue here since it's unrelated to typeguard.

What if jaxtyping had its own typeguard-style instrumentation?

We could write annotations like Float[T, S] = Annotated[T, Shape(S), DType("Float")], and the instrumentation could consume them to validate.

Pros

  • More control over validation, solves #6
  • Seems more in the spirit of PEP 593
  • Gets rid of the Annotated as Float hack

Cons

  • It's instrumentation
  • Could be overkill
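The proposal relies on PEP 593 metadata passing through `typing.Annotated` untouched, so that a checker can read it back with `typing.get_type_hints(..., include_extras=True)`. The sketch below assumes hypothetical `Shape` and `DType` marker classes, a toy `FakeArray` with `shape` and `dtype` attributes, and a decorator `validate` that checks only arguments and only named axes. None of these exist in jaxtyping or typeguard.

```python
import functools
import inspect
from dataclasses import dataclass
from typing import Annotated, get_args, get_origin, get_type_hints


@dataclass(frozen=True)
class Shape:
    spec: str  # e.g. "batch dim"


@dataclass(frozen=True)
class DType:
    kind: str  # e.g. "float32"


@dataclass
class FakeArray:
    shape: tuple
    dtype: str


def validate(fn):
    """Check Annotated[..., Shape(...), DType(...)] arguments on every call."""
    hints = get_type_hints(fn, include_extras=True)
    signature = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        memo = {}  # named axes bound during this call
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # not an Annotated hint: nothing to check here
            base, *metadata = get_args(hint)
            if not isinstance(value, base):
                raise TypeError(f"{name}: expected {base.__name__}")
            for meta in metadata:
                if isinstance(meta, Shape):
                    axes = meta.spec.split()
                    if len(axes) != len(value.shape):
                        raise TypeError(f"{name}: expected {len(axes)} axes")
                    for axis, size in zip(axes, value.shape):
                        if memo.setdefault(axis, size) != size:
                            raise TypeError(f"{name}: axis {axis!r} mismatch")
                elif isinstance(meta, DType) and value.dtype != meta.kind:
                    raise TypeError(f"{name}: expected {meta.kind}, got {value.dtype}")
        return fn(*args, **kwargs)

    return wrapper


@validate
def add(x: Annotated[FakeArray, Shape("n"), DType("float32")],
        y: Annotated[FakeArray, Shape("n"), DType("float32")]) -> None:
    pass


add(FakeArray((3,), "float32"), FakeArray((3,), "float32"))  # passes
```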

Using `jaxtyping` with torch

Hi,

First off, thanks for the great library!

I'm really interested in using jaxtyping-style syntax with torch, and was wondering:

  • Are there any plans to backport the recent syntax changes to torchtyping?
    • torchtyping is great as-is, but I've been holding off on using it in various projects because of incompatibility with static checkers. The recent shift to the typing.Annotated-style syntax seems to fix that.
  • And if not, would it be possible to release jaxtyping without a dependency on jax? It seems like it supports torch pretty well already; I'd just prefer not to ship torch projects with a jax dependency.

Thanks again!!

ImportError: cannot import name 'Array' from 'jaxtyping'

import jaxtyping
print(jaxtyping.__version__) # returns 0.2.14
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float32, jaxtyped

returns

ImportError: cannot import name 'Array' from 'jaxtyping' (/home/jovyan/conda/lib/python3.8/site-packages/jaxtyping/__init__.py)

mypy type checking seems to break in strict mode -- a mypy bug?

Following up on patrick-kidger/torchtyping#41 I'm trying the same things here. However I'm not really having a big success so far with mypy. Am I doing anything wrong?

import torch
from jaxtyping import Float

dim1 = "dim1"


# Expected to work but fails with:
# error: Returning Any from function declared to return "Tensor"
def simple_test_a(x: Float[torch.Tensor, "dim1"]) -> torch.Tensor:
    return x


# Expected to work but fails with:
# error: Returning Any from function declared to return "float"
def simple_test_b(x: Float[torch.Tensor, "dim1"]) -> float:
    return x.item()


# Expected to error, but passes type checking
def simple_test_c(x: Float[torch.Tensor, "dim1"]) -> None:
    x.asdfasdfasdf()

VSCode (pyright) seems to do a little better, but apparently doesn't like the import:

[Screenshot: pyright warning on the import]

Incompatibility with latest typeguard version

I was updating my typeguard version to the latest release on Github: https://github.com/agronholm/typeguard (4.0.0rc5), and now get conflicts when trying to shape annotate torch Tensors (the code was working fine with previous typeguard versions). In particular running the following simple snippet:

from typeguard import typechecked
from jaxtyping import jaxtyped, Float
from torch import Tensor

@jaxtyped
@typechecked
def foo(x: Float[Tensor, 'a b c']):
    return 1

Will throw this error:

  File "/nfs/homedirs/fuchsgru/graph-active-learning/foo.py", line 6, in <module>
    @typechecked
     ^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_decorators.py", line 213, in typechecked
    retval = instrument(target)
             ^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_decorators.py", line 54, in instrument
    instrumentor.visit(module_ast)
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 561, in visit_Module
    self.generic_visit(node)
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 494, in generic_visit
    value = self.visit(value)
            ^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 672, in visit_FunctionDef
    annotation = self._convert_annotation(deepcopy(arg.annotation))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 546, in _convert_annotation
    new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 339, in visit
    new_node = super().visit(node)
               ^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 391, in visit_Subscript
    items = [self.visit(item) for item in slice_value.elts]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 391, in <listcomp>
    items = [self.visit(item) for item in slice_value.elts]
             ^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 339, in visit
    new_node = super().visit(node)
               ^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 442, in visit_Constant
    expression = ast.parse(node.value, mode="eval")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 50, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    a b c
      ^
SyntaxError: invalid syntax

Seems like typeguard is parsing the annotation "a b c" with ast, which does not like this format. Is there any workaround for this? I need the latest typeguard version as earlier versions do not seem to support annotating @property decorated functions.

Relevant versions of typeguard and jaxtyping:

typeguard                4.0.0rc5.post1
jaxtyping                0.2.15
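The parse failure can be reproduced with the standard library alone. The traceback ends in `ast.parse(node.value, mode="eval")`, and a space-separated dimension string is not a valid Python expression, whereas a single name is. This sketch only illustrates why the parse fails; it is not a workaround.

```python
import ast


def parses_as_expression(text):
    """Return True if `text` is accepted by ast.parse in "eval" mode."""
    try:
        ast.parse(text, mode="eval")
    except SyntaxError:
        return False
    return True


print(parses_as_expression("a"))      # True: a lone name looks like a forward reference
print(parses_as_expression("a b c"))  # False: the SyntaxError from the traceback
```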

str: int option?

Hi,

torchtyping has the option for a str: int pair, e.g.:

x: TT["cols": 3, "rows": 4]

which was helpful if you wanted to both (1) name the dimensions and (2) fix them to specific values. Does jaxtyping have a similar feature? If it does, I can't find it currently. (e.g. it seems like Float[t.Tensor, "cols=3 rows=4"] (or : rather than =) would be quite natural here, but that doesn't seem to be allowed).
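Here is a hypothetical parser for the `name=size` form suggested above. It only illustrates the requested syntax and is not a jaxtyping API. Each token becomes a `(name, size)` pair, where either part may be missing.

```python
def parse_dims(spec):
    """Parse "cols=3 rows=4 batch 5" into [(name, size), ...].

    "name=size" gives both, "name" gives (name, None), and "5" gives (None, 5).
    """
    dims = []
    for token in spec.split():
        name, sep, size = token.partition("=")
        if sep:
            dims.append((name, int(size)))
        elif token.isdigit():
            dims.append((None, int(token)))
        else:
            dims.append((token, None))
    return dims


print(parse_dims("cols=3 rows=4"))  # [('cols', 3), ('rows', 4)]
```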

Thanks for the awesome library! (-:

Ruff Compatibility

Hi! I'm trying to use Ruff (https://github.com/charliermarsh/ruff) as my Python linter, but getting some weird rule errors when plugging in to jaxtyping.

Here's a minimal example:

import numpy as np
from jaxtyping import Float32

def hello(x: Float32[np.ndarray, "batch dim"]) -> Float32[np.ndarray, "batch"]:
    return x[:, 1]

if __name__ == "__main__":
    print(hello(np.random.randn(10, 8)))

Running ruff minimal.py (the above file) results in two errors:

  • F722 "Syntax error in forward annotation: batch dim"
  • F821 "Undefined name batch"

Not sure if this is a jaxtyping or a ruff problem, but saw that jaxtyping is using Ruff to lint now, so curious if this has come up!

Array vs PyTree annotation

Hi!

I'm curious why the API for array annotations is Float[Array, "dims"] and not Array[float, "dims"]? The latter would make it consistent with PyTree[float].

Single / multiple dispatch with jaxtyping types

Is there a way to use jaxtyping types with functools.singledispatch or the multipledispatch library? I tried with the following code with no success:

import functools as ft
import jax.numpy as jnp
from jaxtyping import Array, Shaped, jaxtyped

@jaxtyped
@ft.singledispatch
def f(image):
    raise NotImplementedError(f"{type(image)} not implemented")

@jaxtyped
@f.register
def _(image: Shaped[Array, "h w 4"]):
    # process a RGBA image
    pass

@jaxtyped
@f.register
def _(image: Shaped[Array, "h w 3"]):
    # process a RGB image
    pass

@jaxtyped
@f.register
def _(image: Shaped[Array, "h w"]):
    # process a grayscale image
    pass

f(jnp.ones((100, 100, 3)))  # NotImplementedError: <class 'jaxlib.xla_extension.Array'> not implemented
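`functools.singledispatch` selects an implementation from the class of the first argument alone. Three annotations that all resolve to the same array class therefore cannot be told apart by shape. The stdlib sketch below first shows that limitation, then dispatches on the shape instead using a hypothetical `dispatch_on_shape` helper. Plain shape tuples stand in for an array's `.shape`.

```python
import functools


@functools.singledispatch
def describe(image):
    raise NotImplementedError(f"{type(image)} not implemented")


@describe.register
def _(image: tuple):
    return "tuple"  # chosen for every tuple, whatever its length or contents


def dispatch_on_shape(*cases):
    """Return a function that calls the first handler whose predicate accepts the shape."""

    def dispatch(shape):
        for predicate, handler in cases:
            if predicate(shape):
                return handler(shape)
        raise NotImplementedError(f"shape {shape} not implemented")

    return dispatch


process = dispatch_on_shape(
    (lambda s: len(s) == 3 and s[-1] == 4, lambda s: "RGBA"),
    (lambda s: len(s) == 3 and s[-1] == 3, lambda s: "RGB"),
    (lambda s: len(s) == 2, lambda s: "grayscale"),
)

print(describe((100, 100, 3)), describe((100, 100)))  # tuple tuple
print(process((100, 100, 3)))  # RGB
print(process((100, 100)))     # grayscale
```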

Feature wish, shape functions

In doing annotation over a large library I found myself unable to annotate some matrix operations fully. For example, the library contains an SVD function and the output size depends on the MIN of two input matrix dimensions. (See e.g. https://en.wikipedia.org/wiki/Singular_value_decomposition). So, for example for their SVD function what I would like to be able to write is:

def svd(
        self: Float[LinearOperator, "*batch M N"]
    ) -> Tuple[Float[LinearOperator, "*batch M M"], Float[Tensor, "... min(M,N)"], Float[LinearOperator, "*batch N N"]]

But there is no notion of a 'min' function over the dimensions.

Similarly, the library I was annotating has many places where it will optionally broadcast over a "batch" of matrices. So a matrix multiplication like A*B may contain a batch dimension in either A or B. In this case, I would like to be able to check the batch size of the output, with something like what is listed below, where I look for a non-empty batch parameter:

def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Float[torch.Tensor, "*batch2 N C"],
    ) -> Float[torch.Tensor, "coalesce(batch, batch2) M C"]:
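Both wishes reduce to simple arithmetic on shape tuples. In a full SVD of an `M x N` matrix, the factors have shapes `(M, M)`, `(min(M, N),)` and `(N, N)`. One reasonable reading of `coalesce(batch, batch2)` is broadcasting the two batch shapes, which reduces to picking the non-empty one when the other is empty. The stdlib sketch below uses hypothetical helper names. `broadcast_batch` follows NumPy's broadcasting rule: align shapes from the right, and require sizes to match or one of them to be 1.

```python
def svd_shapes(batch, m, n):
    """Shapes of (U, S, V) for a full SVD of a batch of m x n matrices."""
    return batch + (m, m), batch + (min(m, n),), batch + (n, n)


def broadcast_batch(a, b):
    """Broadcast two batch shapes using NumPy's rules."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)  # left-pad with size-1 axes
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(out)


print(svd_shapes((8,), 5, 3))  # ((8, 5, 5), (8, 3), (8, 3, 3))
# Output of _matmul: broadcast(batch, batch2) followed by (M, C).
print(broadcast_batch((8,), ()) + (5, 2))  # (8, 5, 2)
```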

Add `dtype` type

Hi, I wanted to ask whether there could be a type annotation or alias for specifying compatible dtype arguments.

Maybe this already exists, but I'm not yet aware of it.

For example:

import typing
import jaxtyping

import jax
import jax.numpy as jnp


# Right now: `dtype = ??` or jax._src.typing defines `dtype = Union[str, np.dtype, Any, SupportsDType]`.
def myfun(shape: int | typing.Sequence[int], dtype: ...) -> jaxtyping.Num[jaxtyping.Array, '...']:
    return jnp.zeros(shape, dtype)


# Desired
def myfun(shape: int | typing.Sequence[int], dtype: jaxtyping.JaxDTypeLike) -> jaxtyping.Num[jaxtyping.Array, '...']:
    return jnp.zeros(shape, dtype)

# Maybe interesting?
@typing.overload
def myfun(shape: int | typing.Sequence[int], dtype: jaxtyping.JaxIntTypeLike) -> jaxtyping.Integer[jaxtyping.Array, '...']:
    ...

@typing.overload
def myfun(shape: int | typing.Sequence[int], dtype: jaxtyping.JaxFloatTypeLike) -> jaxtyping.Float[jaxtyping.Array, '...']:
    ...

I saw that jax.typing.DTypeLike exists, but it is still private...

This would e.g., be useful for annotating neural network parameter initializers, which often depend on an explicit dtype.

Runtime type checking via `typeguard` causes `TypeError` due to arrays having type `DeviceArray`.

I'm trying to use jaxtyping with runtime type checking via typeguard as described here. Here's my code:

import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from typeguard import typechecked as typechecker


@jaxtyped
@typechecker
def foo(
    x: Float[Array, "n"],
    y: Float[Array, "n"],
) -> Float[Array, "n"]:
    return x + y

print(foo(jnp.arange(10), jnp.arange(10)))

However when I run the above script, I get the following error:

Traceback (most recent call last):
  File "/Users/jay/playground/myscript.py", line 14, in <module>
    print(foo(jnp.arange(10), jnp.arange(10)))
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/jaxtyping/decorator.py", line 41, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "x" must be jaxtyping.Float[ndarray, 'n']; got jaxlib.xla_extension.DeviceArray instead

Steps to reproduce my python environment (Note: I'm running this on an M1 Macbook Pro with macOS Monterey 12.2 (21D49)):

$ python -V
Python 3.9.10

$ python -m venv .venv

$ source .venv/bin/activate

$ python -m pip install --upgrade pip

$ python -m pip install "jax[cpu]==0.3.17" "jaxtyping==0.2.5"

[Documentation] Comparison with `tensor_annotations`

GitHub just recommended @deepmind's JAX-friendly tensor_annotations typing API in my feed. I was fascinated to see something I'd never seen before: a tensor typing API covering not one but two tensor frameworks. Cue chilling Halloween face. 😱

Truly, the golden era of One Typing API to Rule Them All is upon us.

Documentation: You Bedevil Us Yet Again

The lack of runtime type-checking renders tensor_annotations ignorable from the (...admittedly limited) perspective of @beartype. Still, it might be useful to publicly compare and contrast the differences between these two packages here at jaxtyping.

The high-level tl;dr might be:

  • jaxtyping covers only JAX but supports both runtime and static type-checking.
  • tensor_annotations covers both JAX and TensorFlow but supports only static type-checking.

Tensors: The Future Begins Tomorrow

This gently leads into an orthogonal (and much larger) topic that probably deserves its own issue, too. But we are lazy. So let's idly chat here instead.

JAX is a Google product. TensorFlow is a Google product. Both expose tensors to Python. jaxtyping currently only types JAX. Could jaxtyping theoretically be expanded to type TensorFlow as well? If so, victory.

Thanks so much for all the runtime type-checking support, @patrick-kidger and friends. You do the AI's good work.

jaxtyping jt.decorator.storage.memo_stack sometimes contains __builtins__ variables

Sometimes, when inspecting memo_stack, I see a __builtins__ key containing all the builtins; I think this is not the intended behaviour. I didn't dig too deep, but this most often happens when I have arithmetic dimensions such as batch+1.

Here is a repro:

https://colab.research.google.com/drive/17WYz1ZJuPz1NRomHYlUMDLQegBxpInnP#scrollTo=Q4qoDvog3FM0

import typeguard
import jax.numpy as jnp
import jaxtyping as jt


@typeguard.typechecked
def foo(i: jt.Int32[jt.Array, "batch"], j: jt.Int32[jt.Array, "batch+1"]) -> jt.Int32[jt.Array, "batch"]:
    return i


@jt.jaxtyped
def bar():
    try:
        foo(jnp.array([1, 2]), jnp.array([2, 2, 1]))
    finally:
        print(jt.decorator.storage.memo_stack)


bar()

results in:

[({'batch': 2, '__builtins__': {'__name__': 'builtins', '__doc__': "Built-in functions, exceptions, and other objects.\n\nNoteworthy: None is the `nil' object; Ellipsis represents `...' in slices.", '__package__': '', '__loader__': <class '_frozen_importlib.BuiltinImporter'>, '__spec__': ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>, origin='built-in'), '__build_class__': <built-in function __build_class__>, '__import__': <built-in function __import__>, 'abs': <built-in function abs>, 'all': <built-in function all>, 'any': <built-in function any>, 'ascii': <built-in function ascii>, 'bin': <built-in function bin>, 'breakpoint': <built-in function breakpoint>, 'callable': <built-in function callable>, 'chr': <built-in function chr>, 'compile': <built-in function compile>, 'delattr': <built-in function delattr>, 'dir': <built-in function dir>, 'divmod': <built-in function divmod>, 'eval': <built-in function eval>, 'exec': <built-in function exec>, 'format': <built-in function format>, 'getattr': <built-in function getattr>, 'globals': <built-in function globals>, 'hasattr': <built-in function hasattr>, 'hash': <built-in function hash>, 'hex': <built-in function hex>, 'id': <built-in function id>, 'input': <bound method Kernel.raw_input of <google.colab._kernel.Kernel object at 0x7f61db3c7cd0>>, 'isinstance': <built-in function isinstance>, 'issubclass': <built-in function issubclass>, 'iter': <built-in function iter>, 'aiter': <built-in function aiter>, 'len': <built-in function len>, 'locals': <built-in function locals>, 'max': <built-in function max>, 'min': <built-in function min>, 'next': <built-in function next>, 'anext': <built-in function anext>, 'oct': <built-in function oct>, 'ord': <built-in function ord>, 'pow': <built-in function pow>, 'print': <built-in function print>, 'repr': <built-in function repr>, 'round': <built-in function round>, 'setattr': <built-in function setattr>, 'sorted': <built-in function sorted>, 'sum': 
<built-in function sum>, 'vars': <built-in function vars>, 'None': None, 'Ellipsis': Ellipsis, 'NotImplemented': NotImplemented, 'False': False, 'True': True, 'bool': <class 'bool'>, 'memoryview': <class 'memoryview'>, 'bytearray': <class 'bytearray'>, 'bytes': <class 'bytes'>, 'classmethod': <class 'classmethod'>, 'complex': <class 'complex'>, 'dict': <class 'dict'>, 'enumerate': <class 'enumerate'>, 'filter': <class 'filter'>, 'float': <class 'float'>, 'frozenset': <class 'frozenset'>, 'property': <class 'property'>, 'int': <class 'int'>, 'list': <class 'list'>, 'map': <class 'map'>, 'object': <class 'object'>, 'range': <class 'range'>, 'reversed': <class 'reversed'>, 'set': <class 'set'>, 'slice': <class 'slice'>, 'staticmethod': <class 'staticmethod'>, 'str': <class 'str'>, 'super': <class 'super'>, 'tuple': <class 'tuple'>, 'type': <class 'type'>, 'zip': <class 'zip'>, '__debug__': True, 'BaseException': <class 'BaseException'>, 'Exception': <class 'Exception'>, 'TypeError': <class 'TypeError'>, 'StopAsyncIteration': <class 'StopAsyncIteration'>, 'StopIteration': <class 'StopIteration'>, 'GeneratorExit': <class 'GeneratorExit'>, 'SystemExit': <class 'SystemExit'>, 'KeyboardInterrupt': <class 'KeyboardInterrupt'>, 'ImportError': <class 'ImportError'>, 'ModuleNotFoundError': <class 'ModuleNotFoundError'>, 'OSError': <class 'OSError'>, 'EnvironmentError': <class 'OSError'>, 'IOError': <class 'OSError'>, 'EOFError': <class 'EOFError'>, 'RuntimeError': <class 'RuntimeError'>, 'RecursionError': <class 'RecursionError'>, 'NotImplementedError': <class 'NotImplementedError'>, 'NameError': <class 'NameError'>, 'UnboundLocalError': <class 'UnboundLocalError'>, 'AttributeError': <class 'AttributeError'>, 'SyntaxError': <class 'SyntaxError'>, 'IndentationError': <class 'IndentationError'>, 'TabError': <class 'TabError'>, 'LookupError': <class 'LookupError'>, 'IndexError': <class 'IndexError'>, 'KeyError': <class 'KeyError'>, 'ValueError': <class 'ValueError'>, 
'UnicodeError': <class 'UnicodeError'>, 'UnicodeEncodeError': <class 'UnicodeEncodeError'>, 'UnicodeDecodeError': <class 'UnicodeDecodeError'>, 'UnicodeTranslateError': <class 'UnicodeTranslateError'>, 'AssertionError': <class 'AssertionError'>, 'ArithmeticError': <class 'ArithmeticError'>, 'FloatingPointError': <class 'FloatingPointError'>, 'OverflowError': <class 'OverflowError'>, 'ZeroDivisionError': <class 'ZeroDivisionError'>, 'SystemError': <class 'SystemError'>, 'ReferenceError': <class 'ReferenceError'>, 'MemoryError': <class 'MemoryError'>, 'BufferError': <class 'BufferError'>, 'Warning': <class 'Warning'>, 'UserWarning': <class 'UserWarning'>, 'EncodingWarning': <class 'EncodingWarning'>, 'DeprecationWarning': <class 'DeprecationWarning'>, 'PendingDeprecationWarning': <class 'PendingDeprecationWarning'>, 'SyntaxWarning': <class 'SyntaxWarning'>, 'RuntimeWarning': <class 'RuntimeWarning'>, 'FutureWarning': <class 'FutureWarning'>, 'ImportWarning': <class 'ImportWarning'>, 'UnicodeWarning': <class 'UnicodeWarning'>, 'BytesWarning': <class 'BytesWarning'>, 'ResourceWarning': <class 'ResourceWarning'>, 'ConnectionError': <class 'ConnectionError'>, 'BlockingIOError': <class 'BlockingIOError'>, 'BrokenPipeError': <class 'BrokenPipeError'>, 'ChildProcessError': <class 'ChildProcessError'>, 'ConnectionAbortedError': <class 'ConnectionAbortedError'>, 'ConnectionRefusedError': <class 'ConnectionRefusedError'>, 'ConnectionResetError': <class 'ConnectionResetError'>, 'FileExistsError': <class 'FileExistsError'>, 'FileNotFoundError': <class 'FileNotFoundError'>, 'IsADirectoryError': <class 'IsADirectoryError'>, 'NotADirectoryError': <class 'NotADirectoryError'>, 'InterruptedError': <class 'InterruptedError'>, 'PermissionError': <class 'PermissionError'>, 'ProcessLookupError': <class 'ProcessLookupError'>, 'TimeoutError': <class 'TimeoutError'>, 'open': <built-in function open>, 'copyright': Copyright (c) 2001-2023 Python Software Foundation.
All Rights Reserved.

Copyright (c) 2000 BeOpen.com.
All Rights Reserved.

Copyright (c) 1995-2001 Corporation for National Research Initiatives.
All Rights Reserved.

Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam.
All Rights Reserved., 'credits':     Thanks to CWI, CNRI, BeOpen.com, Zope Corporation and a cast of thousands
    for supporting Python development.  See [www.python.org](http://www.python.org/) for more information., 'license': Type license() to see the full license text, 'help': Type help() for interactive help, or help(object) for help about object., '__IPYTHON__': True, 'display': <function display at 0x7f61e2b089d0>, 'execfile': <function execfile at 0x7f61ca4dd240>, 'runfile': <function runfile at 0x7f61ca31add0>, '__pybind11_internals_v4_gcc_libstdcpp_cxxabi1013__': <capsule object NULL at 0x7f61ca1c1560>, 'get_ipython': <bound method InteractiveShell.get_ipython of <google.colab._shell.Shell object at 0x7f61db3c7d60>>}}, {}, {})]
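The `__builtins__` entry matches documented behaviour of Python's `eval`. If the globals dictionary passed to it has no `__builtins__` key, a reference to the builtins module's dictionary is inserted under that key before evaluation. The stdlib sketch below assumes, without having checked jaxtyping's source, that the memo dictionary is passed to `eval` as globals when evaluating `batch+1`. Under that assumption, passing a copy or a separate globals dict keeps the memo clean.

```python
memo = {"batch": 2}

# Passing the memo itself as globals: eval adds a "__builtins__" key to it.
assert eval("batch+1", memo) == 3
print(sorted(memo))  # ['__builtins__', 'batch']

# Passing a copy (or a separate globals dict) leaves the original untouched.
clean = {"batch": 2}
assert eval("batch+1", dict(clean)) == 3
print(sorted(clean))  # ['batch']
```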

[RFC] Syntax changes

There's a planned rewrite to change the syntax for jaxtyping.

This is available here:
https://github.com/google/jaxtyping/tree/rewrite

In short, the changes are:

  • jaxtyping.Float rather than just jaxtyping.f to denote precision-independent types. (This avoids names like i and f, which are commonly used as variable names and are a bit opaque about what they actually mean.)
    • Likewise
       i -> IntSign
       u -> IntUnsign
       t -> Int
       x -> Inexact
       c -> Complex
       n -> Num
      
  • Changing from Float["batch length channels"] to Float[jnp.ndarray, "batch length channels"].
    • If you're not specifying the dtype at all, e.g. jaxtyping.Array["foo bar"], then this is now jaxtyping.Shaped[jnp.ndarray, "foo bar"].
    • For the sake of neat syntax we now have jaxtyping.Array = jnp.ndarray, so that you can use the nicer-looking Array instead of jnp.ndarray, if you wish.

Regarding this latter change:
Pros:

  • Partial compatibility with static type checking. (Hurrah!) Float[jnp.ndarray, "foo"] will now smoothly fall back to being treated as just jnp.ndarray by static type checkers, instead of just being hopelessly incompatible.
  • The ability to specify non-JAX-array types! Including NumPy/TensorFlow/PyTorch.
  • Compatibility with the upcoming jax.typing namespace, which has the more limited aim of supporting static type checking; the plan is for jaxtyping to be a superset of jax.typing.

Cons:

  • A little more verbose.

Particular questions I'd welcome feedback on:

  • Alternate project names? Having both jax.typing and jaxtyping is a bit confusing. (But perhaps not too bad, if jaxtyping is a strict superset.)
  • Alternate names for IntSign and IntUnsign? (SInt? IntS?)
  • Alternate names for Float, e.g. just F?
  • Alternate names for f32, e.g. Float32? (At present these are still the short ones.)
  • Any strong opinions about the Float["foo"] -> Float[Array, "foo"] change mentioned above? (I've already chatted with a few Google-internal people who have strong opinions both for and against this, haha.)

CC @thomaspinder @heytanay @daniel-dodd @ayaka14732 as the folks GitHub currently lists as using this project in public repos.

TorchScript compatibility

Hi,

The package no longer requires JAX and can also work with torch.Tensor.
I tried this and it worked fine, but it fails when using torch.jit.script.
In TorchTyping you fixed this with patrick-kidger/torchtyping#13.
Is there something similar here?

Error message could include more shape information

The following code passes type checking and runs without error:

import jax
from typeguard import typechecked as typechecker
from jaxtyping import f32, u, jaxtyped

@jaxtyped
@typechecker
def standardize(x : f32["N"], eps=1e-5):
    return (x - x.mean()) / (x.std() + eps)

rng = jax.random.PRNGKey(42)

embeddings = jax.random.uniform(rng, (11,))
t1 = standardize(embeddings)

The following code correctly fails type checking, but the message would ideally tell us why the shapes don't match:

embeddings = jax.random.uniform(rng, (11,13))
t1 = standardize(embeddings)
# TypeError: type of argument "x" must be jaxtyping.array_types.f32['N']; got jaxlib.xla_extension.DeviceArray instead

Ideally, this would be something like

# TypeError: type of argument "x" must be jaxtyping.array_types.f32['N']; got jaxlib.xla_extension.DeviceArray(dtype=float32,shape=(11,13)) instead

Wrong error message when using jaxtyping with equinox

Thanks for creating a lot of amazing libraries on JAX ecosystem!

As an exploration, I am trying out equinox and want to annotate a loss function with jaxtyping.
While the runtime type check does catch the wrong shape of my input, it flags the wrong argument as the cause of the error.

Here is a snippet of my code to highlight the issue (refer to my colab notebook below for the complete version):

@jaxtyped
@typechecked
@jax.jit
@jax.grad
def loss_fn(
    model: Linear,
    x: Float[Array, "batch in_dim"],
    y: Float[Array, "batch out_dim"]
) -> Linear:
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
wrong_x = jax.numpy.zeros((50, in_size))
grads = loss_fn(model, x=wrong_x, y=y)

The error message should say that argument x is wrong, but here is the message I received:

TypeError                                 Traceback (most recent call last)
<ipython-input-9-6d78a92b3a76> in <module>
      1 wrong_x = jax.numpy.zeros((50, in_size))
----> 2 grads = loss_fn(model, x=wrong_x, y=y)

2 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_argument_types(memo)
    873                 check_type(description, value, expected_type, memo)
    874             except TypeError as exc:  # suppress unnecessarily long tracebacks
--> 875                 raise TypeError(*exc.args) from None
    876 
    877     return True

TypeError: type of argument "y" must be jaxtyping.array_types.Float[ndarray, 'batch out_dim']; got jaxlib.xla_extension.DeviceArray instead

Package
JAX version: 0.3.17
eqx version: 0.7.1
jaxtyping: 0.2.0
typeguard as runtime type checking

Simple colab notebook for reproducing the bug
https://colab.research.google.com/drive/10rSs6IhNmU7lvhPxxlU2ext8JkgBLULI?usp=sharing

(As a side note, I am not sure whether annotating the model by its class is good practice; feel free to comment.)
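A plausible explanation, not confirmed in the report: under @jaxtyped, axis names are bound in argument order. wrong_x is checked first and binds batch=50, so y, whose batch axis is 32, is the first argument that conflicts, and it is the one reported. The hypothetical function below mimics this first-come binding in plain Python; it is an illustration, not jaxtyping's implementation.

```python
def first_conflict(args):
    """Return the name of the first argument whose axis sizes clash with earlier ones.

    `args` is a list of (name, axis_names, shape) triples, checked left to right.
    Rank checks are omitted for brevity.
    """
    bound = {}  # axis name -> size, fixed by the first argument that uses it
    for name, axes, shape in args:
        for axis, size in zip(axes.split(), shape):
            if bound.setdefault(axis, size) != size:
                return name
    return None
```

With the shapes from the snippet, first_conflict([("x", "batch in_dim", (50, 2)), ("y", "batch out_dim", (32, 3))]) returns "y"; swapping the order of the two entries makes "x" the one blamed instead.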

Operator support

This is not a comprehensive list, but many of the arithmetic operators don't seem to be fully supported yet.

Example:

Pyright: Operator "-" not supported for types "f" and "f"

import jax.numpy as jnp
import jaxtyping

a: jaxtyping.f = jnp.zeros(3)
b: jaxtyping.f = jnp.ones(3)

c = a - b  # Pyright: Operator "-" not supported for types "f" and "f"
print(c)  # prints [-1. -1. -1.]; the code is valid at runtime

pyright 1.1.269 (latest, 2022-09-03)

[Question] Is there a way to annotate equal shape for input and return?

Thank you for this library.

I have a function that accepts a tensor and should return a tensor of the same shape. The input shape can be anything, but the output shape should always match it. Is there a good way to annotate this with the library?

I read the shape API and settled on this:

def f(x: Float[torch.Tensor, "*"]) -> Float[torch.Tensor, "*"]:
    ...

but I believe this would allow f to return a tensor of any shape, not necessarily one that matches x.

Cannot import name 'Key' from jaxtyping.array_types

Hey Patrick,

Just had some tests fail with jaxtyping 2.17 (logs

Serves me right for not pinning a version but thought I'd flag it for you :)

____________ ERROR collecting tests/protein/tensor/test_sequence.py ____________
ImportError while importing test module '/home/runner/work/graphein/graphein/tests/protein/tensor/test_sequence.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/usr/share/miniconda3/envs/test/lib/python3.8/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/protein/tensor/test_sequence.py:12: in <module>
    from graphein.protein.tensor.io import protein_df_to_tensor
graphein/protein/tensor/__init__.py:9: in <module>
    from .data import Protein
graphein/protein/tensor/data.py:26: in <module>
    from .angles import (
graphein/protein/tensor/angles.py:16: in <module>
    from .testing import has_nan
graphein/protein/tensor/testing.py:13: in <module>
    from .sequence import get_atom_indices
graphein/protein/tensor/sequence.py:21: in <module>
    from .types import AtomTensor
graphein/protein/tensor/types.py:12: in <module>
    from jaxtyping import Float, Int
/usr/share/miniconda3/envs/test/lib/python3.8/site-packages/jaxtyping/__init__.py:101: in <module>
    from .array_types import (
E   ImportError: cannot import name 'Key' from 'jaxtyping.array_types' (/usr/share/miniconda3/envs/test/lib/python3.8/site-packages/jaxtyping/array_types.py)

Allow `#` to broadcast over scalars

Currently, broadcasting functions can be annotated using #. From the API docs:

  • * name: zero or more variable-size axes, e.g. f32["*batch c h w"]
  • Append # to a dimension size to indicate that it can be that size or equal to one -- i.e. broadcasting is acceptable.

However functions annotated in this way don't support broadcasting over scalar inputs, even though JAX does.

I would welcome it if f["n#"] (or f["*n#"]) allowed broadcasting over scalar inputs, instead of having to use verbose union types.

MWE

JAX's jnp.add will happily broadcast all of these inputs:

import jax.numpy as jnp
from jaxtyping import f, jaxtyped
from typeguard import typechecked

x_float = jnp.float32(0.1)
y_float = jnp.float32(0.2)

x_array = jnp.array([x_float])
y_array = jnp.array([y_float])

xs = jnp.array([0.1, 0.2])
ys = jnp.array([0.2, 0.3])
>>> jnp.add(xs, ys)
DeviceArray([0.3, 0.5], dtype=float32)

>>> jnp.add(x_array, y_array)
DeviceArray([0.3], dtype=float32)

>>> jnp.add(x_float, y_float)
DeviceArray(0.3, dtype=float32)

>>> jnp.add(xs, x_array)
DeviceArray([0.2, 0.3], dtype=float32)

>>> jnp.add(xs, y_float)
DeviceArray([0.3, 0.4], dtype=float32)

>>> jnp.add(x_array, x_float)
DeviceArray([0.2], dtype=float32)

However, annotating a wrapper function myadd for scalar/vector addition throws an error on any operation involving a scalar input:

@jaxtyped
@typechecked
def myadd(x: f["n#"], y: f["n#"]) -> f["n#"]:
    return jnp.add(x, y)
>>> myadd(xs, ys)
DeviceArray([0.3, 0.5], dtype=float32)

>>> myadd(x_array, y_array)
DeviceArray([0.3], dtype=float32)

>>> myadd(x_float, y_float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "x" must be jaxtyping.array_types.f['n#']; got jaxlib.xla_extension.DeviceArray instead

>>> myadd(xs, x_array)
DeviceArray([0.2, 0.3], dtype=float32)

>>> myadd(xs, y_float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "y" must be jaxtyping.array_types.f['n#']; got jaxlib.xla_extension.DeviceArray instead

>>> myadd(x_array, y_float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "y" must be jaxtyping.array_types.f['n#']; got jaxlib.xla_extension.DeviceArray instead

Adding * to allow zero dimensions, we get f["*n#"], which makes myadd2(x_float, y_float) pass:

@jaxtyped
@typechecked
def myadd2(x: f["*n#"], y: f["*n#"]) -> f["*n#"]:
    return jnp.add(x, y)
>>> myadd2(xs, ys)
DeviceArray([0.3, 0.5], dtype=float32)

>>> myadd2(x_array, y_array)
DeviceArray([0.3], dtype=float32)

>>> myadd2(x_float, y_float)
DeviceArray(0.3, dtype=float32)

>>> myadd2(xs, x_array)
DeviceArray([0.2, 0.3], dtype=float32)

>>> myadd2(xs, y_float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "y" must be jaxtyping.array_types.f['*n#']; got jaxlib.xla_extension.DeviceArray instead

>>> myadd2(x_array, y_float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "y" must be jaxtyping.array_types.f['*n#']; got jaxlib.xla_extension.DeviceArray instead

Using just * (without #), myadd3(x_float, y_float) passes, but myadd3(xs, x_array) fails, as we might expect:

@jaxtyped
@typechecked
def myadd3(x: f["*n"], y: f["*n"]) -> f["*n"]:
    return jnp.add(x, y)
>>> myadd3(xs, ys)
DeviceArray([0.3, 0.5], dtype=float32)

>>> myadd3(x_array, y_array)
DeviceArray([0.3], dtype=float32)

>>> myadd3(x_float, y_float)
DeviceArray(0.3, dtype=float32)

>>> myadd3(xs, x_array)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "y" must be jaxtyping.array_types.f['*n']; got jaxlib.xla_extension.DeviceArray instead

>>> myadd3(xs, y_float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "y" must be jaxtyping.array_types.f['*n']; got jaxlib.xla_extension.DeviceArray instead

>>> myadd3(x_array, y_float)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/jaxtyping/decorator.py", line 34, in wrapper
    return fn(*args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/jax-md-playground-UowHBS0s-py3.10/lib/python3.10/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "y" must be jaxtyping.array_types.f['*n']; got jaxlib.xla_extension.DeviceArray instead

In this specific case, this can be worked around with union types; however, these quickly get very verbose:

from typing import Union

@jaxtyped
@typechecked
def myadd4(x: Union[f[""], f["n#"]], y: Union[f[""], f["n#"]]) -> Union[f[""], f["n#"]]:
    return jnp.add(x, y)
>>> myadd4(xs, ys)
DeviceArray([0.3, 0.5], dtype=float32)

>>> myadd4(x_array, y_array)
DeviceArray([0.3], dtype=float32)

>>> myadd4(x_float, y_float)
DeviceArray(0.3, dtype=float32)

>>> myadd4(xs, x_array)
DeviceArray([0.2, 0.3], dtype=float32)

>>> myadd4(xs, y_float)
DeviceArray([0.3, 0.4], dtype=float32)

>>> myadd4(x_array, y_float)
DeviceArray([0.3], dtype=float32)
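Until something like f["*n#"] accepts scalars, a shorter stopgap than the unions is to check broadcast compatibility directly. A sketch; broadcast_compatible is a hypothetical helper (not jaxtyping API) built on numpy.broadcast_shapes (NumPy 1.20+), which applies the NumPy broadcasting rules that jnp.add also follows, with a scalar treated as shape ().

```python
import numpy as np

def broadcast_compatible(*arrays) -> bool:
    """Hypothetical helper: True if the inputs' shapes broadcast together."""
    shapes = [np.shape(a) for a in arrays]  # scalars give ()
    try:
        np.broadcast_shapes(*shapes)
    except ValueError:
        return False
    return True

xs = np.array([0.1, 0.2], dtype=np.float32)
x_array = np.array([0.1], dtype=np.float32)
y_float = np.float32(0.2)
```

Inside myadd, this could back an explicit assert, at the cost of moving the shape information out of the signature.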

Can we check statement-level annotations?

One of my dreams for this package was to turn code like this

  query = linear(head.query, t1)                  # L x Dk
  key = linear(head.key, t1)                      # L x Dk

into this

  query : LxDk = linear(head.query, t1)
  key   : LxDk = linear(head.key, t1)

where we have written

  LxDk = f32["L Dk"] 

earlier in the @jaxtyped function.

But it looks as if these annotations aren't checked?
I haven't looked into how hard that might be; is it a lot of work?

`jaxtyped(typechecked(..))` doesn't seem to work with generators

Consider the following PyTorch repro:

import jaxtyping
import torch
import typeguard
from typing import Iterator

@jaxtyping.jaxtyped
@typeguard.typechecked
def g(x: jaxtyping.Shaped[torch.Tensor, '*']) -> Iterator[jaxtyping.Shaped[torch.Tensor, '*']]:
    yield x

@jaxtyping.jaxtyped
def f():
    next(g(torch.zeros(1)))
    next(g(torch.zeros(2)))

f()

This yields TypeError: type of value yielded from generator must be jaxtyping.Shaped[Tensor, '*']; got torch.Tensor instead, due to this condition failing with torch.Size([1]) != torch.Size([2]).

jaxtyping not working with threads

Type checking fails when the decorated function is called from another thread:

import jax.numpy as np
from jaxtyping import Array, Float as F, jaxtyped
import threading
from typeguard import typechecked as typechecker

@jaxtyped
@typechecker
def add(x: F[Array, 'a b'], y: F[Array, 'a b']) -> F[Array, 'a b']:
    return x + y

def run():
    a = np.array([[1., 2.]])
    b = np.array([[2., 3.]])
    c = add(a, b)
    print(c)

thread = threading.Thread(target=run)
thread.start()
thread.join()

Error:

Exception in thread Thread-1 (run):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ayaka/Projects/test/1.py", line 14, in run
    c = add(a, b)
  File "/home/ayaka/Projects/test/venv/lib/python3.10/site-packages/jaxtyping/decorator.py", line 31, in wrapper
    storage.memo_stack.append(({}, {}, {}))
AttributeError: '_thread._local' object has no attribute 'memo_stack'

Environment:

  • Python 3.10.6
  • jax 0.3.17
  • jaxlib 0.3.15+cuda11.cudnn82
  • jaxtyping 0.2.1
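The traceback suggests memo_stack is set on a threading.local object only in the thread that imported jaxtyping. Attributes of a threading.local are per-thread, so a new thread sees an empty object. A sketch of the failure mode and one common fix (subclassing threading.local so that __init__ runs in each thread); the names are illustrative, not jaxtyping's actual code.

```python
import threading

# Failure mode: the attribute is set once, in the importing thread only.
plain = threading.local()
plain.memo_stack = []

# Fix: a threading.local subclass re-runs __init__ the first time each thread uses it.
class _Storage(threading.local):
    def __init__(self):
        self.memo_stack = []

storage = _Storage()

def run_in_thread(fn):
    """Run fn in a new thread and return its result."""
    result = []
    thread = threading.Thread(target=lambda: result.append(fn()))
    thread.start()
    thread.join()
    return result[0]
```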

Is it possible to create a parameterized type annotation?

Apologies if this is a naive question.

I have a project where I will reuse the same array spec many times, but in some cases (e.g. modelling) they will be Jax arrays, and in other cases (e.g. preprocessing; analysis) they will be NumPy arrays.

Ideally I'd like to specify the dtype and the shape once, and then pass the array type at each declaration:

Data = Float[<something here>, "batch a b c"]
x: Data[Array] = ...
y: Data[np.ndarray] = ...

My current attempt looks like this:

from typing import TypeVar

import numpy as np
from jaxtyping import Array as JaxArray, Float

Jax = JaxArray
NumPy = np.ndarray
ArrayT = TypeVar("ArrayT", Jax, NumPy)

Data = Float[ArrayT, "batch a b c"]

But this fails with:

E   TypeError: '_MetaAbstractArray' object is not subscriptable

Is it possible to achieve this? If so, what am I doing wrong?

Thank you!

Would/should `isinstance(x,i)` make sense for scalars?

I might want to write

@jax.jit
@jaxtyped
def foo1(x: int, t: f32["N"]):
    y: int = x + 3
    assert isinstance(y, int), f"y is a {type(y)}!"
    ...

But that won't work with JIT:

AssertionError: y is a <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>!

I thought "Aha! I'll just use a jaxtyping.i, that will understand the situation"

@jax.jit
@jaxtyped
def foo1(x: int, t: f32["N"]):
    y: int = x + 3
    assert isinstance(y, jaxtyping.i), f"y is a {type(y)}!"
    ...

But that says

RuntimeError: Do not use `isinstance(x, jaxtyping.i`. If you want to check just the dtype of an array, then use `jaxtyping.i["..."]`.

Is there a way to get this check aware of tracers?

[Edit: Doh, I see that i[""] passes the check. It feels a little verbose, but works fine.]

Unclear errors from beartype/typeguard (possibly misused!)

Hi Patrick! Using jaxtyping for the first time, and ran into some peculiar error messages for a basic RNN example when using beartype/typeguard.

I'll paste the code first, and then the errors:

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker  
# from typeguard import typechecked as typechecker  # uncomment to get typeguard error
from equinox import Module as JaxClass


@jaxtyped
@typechecker
class Parameters(JaxClass):
    embedding_weights: Float[Array, "hidden_state embedding"]
    hidden_state_weights: Float[Array, "hidden_state hidden_state"]
    output_weights: Float[Array, "embedding hidden_state"]
    hidden_state_bias: Float[Array, "hidden_state"]
    output_bias: Float[Array, "embedding"]

e = 12
h = 10

init_pars = Parameters(jnp.ones((h,e)), jnp.ones((h,h)), jnp.ones((e,h)), jnp.ones((h,)), jnp.ones((e,)))
init_state = jnp.ones((h,))
random_embed = jnp.zeros((e,)).at[3].set(1)

@jax.jit
@jaxtyped
@typechecker
def update_hidden_state(
    embedding: Float[Array, "embedding"], 
    hidden_state: Float[Array, "hidden_state"], 
    params: Parameters
) -> Float[Array, "embedding"]:
    return params.hidden_state_weights @ hidden_state + params.embedding_weights @ embedding + params.hidden_state_bias

update_hidden_state(random_embed, init_state, init_pars)

First, beartype: the runtime checking behaves as expected when the shapes match. If, however, I return an array of shape "embedding" in update_hidden_state, I correctly get a complaint, but it's a bit obscure:

BeartypeCallHintReturnViolation: Function __main__.update_hidden_state() return "Tracedwith" violates type hint , as  "Tracedwith" not instance of .

which makes it seem like some function name is not being propagated or something (also fairly new to runtime type checking, but I would assume the error should be a little clearer!)

If I uncomment the typeguard line, it's even less clear what I'm doing wrong:

TypeError:  is a built-in module

which again, seems to be missing the name of something? Hoping this is just me not using something correctly :p

This was reproducible on Python 3.8 and 3.11 virtual envs.

Versions:
typeguard==3.0.2
jax==0.4.8
jaxlib==0.4.7
jaxtyping==0.2.15
equinox==0.10.3
beartype==0.14.0

jaxtyping does not play nicely with inheritance

The code below fails, whether I use typeguard or beartype,
and gives the error

    class LinearGaussianConjugateSSM(LinearGaussianSSM):
TypeError: __init__() takes 2 positional arguments but 4 were given

However if I omit the initial @jaxtyped it works.

from jaxtyping import jaxtyped
#from beartype import beartype as typechecker
from typeguard import typechecked as typechecker


@jaxtyped ### OMIT
@typechecker
class LinearGaussianSSM():
    def __init__(
        self,
        state_dim: int,
        emission_dim: int,
        input_dim: int=0,
        has_dynamics_bias: bool=True,
        has_emissions_bias: bool=True
    ):
        self.state_dim = state_dim
        self.emission_dim = emission_dim
        self.input_dim = input_dim
        self.has_dynamics_bias = has_dynamics_bias
        self.has_emissions_bias = has_emissions_bias


  
@typechecker
class LinearGaussianConjugateSSM(LinearGaussianSSM):
    def __init__(self,
                 state_dim,
                 emission_dim,
                 input_dim=0,
                 has_dynamics_bias=True,
                 has_emissions_bias=True):
        super().__init__(state_dim=state_dim, emission_dim=emission_dim, input_dim=input_dim,
        has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias)

@slinderman
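A plausible mechanism for this error, not confirmed in the report: if the decorator returns a wrapper object rather than a class, then class Child(Base) makes Python call type(Base)("Child", (Base,), namespace), that is, the wrapper's own __init__ with three extra arguments. The hypothetical _Wrapper below reproduces the same message in plain Python; it is not jaxtyping's code.

```python
class _Wrapper:
    """Hypothetical decorator that returns a wrapper object, not the class itself."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


@_Wrapper
class Base:
    pass


try:
    # The metaclass here is type(Base), i.e. _Wrapper, so Python calls
    # _Wrapper("Child", (Base,), namespace) and __init__ gets 4 positional args.
    class Child(Base):
        pass
except TypeError as exc:
    error_message = str(exc)
else:
    error_message = ""
```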

[Question] Type aliases

Is it possible to define type aliases for commonly used types?

e.g. f["a b c"] for Float[torch.Tensor, "a b c"]

[Question] Partial shape initialization

I am a newbie with jaxtyping; I am wondering if it is possible to have partial initialization of shapes or nested definitions.

In my case, I have a shape for a sample of data that I would like to define once:

MyDataType = Float[Array, "chan1 chan2 chan3 features"]

And then, I would like to reuse MyDataType definition multiple times to define batches of different lengths.
I tried something like:

Batch = Float[MyDataType, "batch_size"]

but it doesn't see batch_size as a new axis.

An alternative useful feature for doing this would be a partial initialization like:

Batch = Float[Array, "$1 chan1 chan2 chan3 features"]

Batch["n"] == Float[Array, "n chan1 chan2 chan3 features"]

A workaround like this works:

class _MetaBatch(type):
    def __getitem__(cls, s: str):
        return Float[Array, f"{s} chan1 chan2 chan3 features"]

class Batch(metaclass=_MetaBatch):
  pass

Is there a more elegant way of doing this?

Strange interaction between Shaped and Union

The jaxtyping.Shaped / Float / Int etc. interact strangely with typing.Union:

from typing import Union
from jaxtyping import Array, Shaped
import jax.numpy as jnp

x = jnp.zeros([3])

# These all work:
assert isinstance(x, Array)
assert isinstance(x, Union[Array, int])
assert isinstance(x, Shaped[Array, "_"])
# But this one fails:
assert isinstance(x, Union[Shaped[Array, "_"], int])  # <<  AssertionError

# interestingly this one works:
assert isinstance(x, Shaped[Array, "_"] | int)

My use case was to define an alias that accepts both jax.Array and np.ndarray, which I first tried, without luck, like so:

assert isinstance(x, Shaped[Union[Array, np.ndarray], "_"])  # << AssertionError
assert isinstance(x, Shaped[Array | np.ndarray, "_"])
# TypeError: type 'types.UnionType' is not an acceptable base type

For now I can use Shaped[Array, "_"] | Shaped[np.ndarray, "_"], but this behavior was very surprising to me and seems like a bug.

Why not `TypeVarTuple`?

TypeVarTuple is about to be introduced in Python 3.11, but it's already usable with typing_extensions.

Docs: https://docs.python.org/3.11/library/typing.html#typing.TypeVarTuple

Example:

from typing import Generic, Literal, TypeVar
from typing_extensions import TypeVarTuple, Unpack

dims = TypeVarTuple("dims")
dtype = TypeVar("dtype")
dim1 = Literal["dim1"]
dim2 = Literal["dim2"]
dim3 = Literal["dim3"]
batch = Literal["batch"]
c1 = Literal["c1"]
c2 = Literal["c2"]

class f32(Generic[Unpack[dims]]): ...
class PyTree(Generic[dtype]): ...

def matrix_multiply(x: f32[dim1, dim2], y: f32[dim2, dim3]) -> f32[dim1, dim3]:
    ...

def accepts_pytree_of_ints(x: PyTree[int]):
    ...

def accepts_pytree_of_arrays(x: PyTree[f32[batch, c1, c2]]):
    ...

pytest plugin broken in >=v0.2.16

It seems that in versions >=v0.2.16 the pytest plugin no longer works properly:

pytest: error: unrecognized arguments: --jaxtyping-packages=my_package,beartype.beartype

Works fine in v0.2.15.

View full types in VSCode on-hover type hints

When using jaxtyping, I noticed that, even if a function's argument is typed as e.g. Float[Tensor, "b h"], when hovering over the function in VSCode, its argument is typed only as Tensor. Is there some way to get VSCode to display the full type of the function on hover?
