Comments (3)
Hmm, this looks like an upstream bug in JAX. I'd suggest opening an issue over there.
Here's a minimal repro without using jaxtyping:
```python
import beartype
import jax

# Fails at static type-checking time
def f(x: jax.Array):
    pass

f(jax.random.PRNGKey(0))
# pyright says:
# Argument of type "KeyArray" cannot be assigned to parameter "x" of type "Array" in function "f"

# Fails at runtime
@beartype.beartype
def g(x: jax.random.KeyArray):
    pass

g(jax.random.PRNGKey(0))
# beartype.roar.BeartypeCallHintParamViolation: @beartyped __main__.g() parameter x="Array([0, 0], dtype=uint32)" violates type hint <class 'jax._src.prng.PRNGKeyArray'>, as <protocol "jaxlib.xla_extension.ArrayImpl"> "Array([0, 0], dtype=uint32)" not instance of <class "jax._src.prng.PRNGKeyArray">
```
The problem, as you've found, is that neither annotation is correct in all cases.
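The runtime half of the mismatch is easy to inspect: with default settings, `jax.random.PRNGKey` returns an ordinary `uint32` array rather than an instance of a dedicated key class, which is what the beartype error above reports. The sketch below assumes a recent JAX release (one that provides `jax.random.key` and `jax.dtypes.prng_key`, from the typed keys of JEP 9263) running with the default PRNG settings; `describe` is a hypothetical helper, not part of JAX.

```python
import jax

# Old-style key: with the default PRNG settings this is an ordinary uint32
# jax.Array of shape (2,), not an instance of a dedicated key class.
raw_key = jax.random.PRNGKey(0)

# New-style typed key (JEP 9263): a scalar array with an extended key dtype.
typed_key = jax.random.key(0)


def describe(key: jax.Array) -> tuple:
    """Hypothetical helper: return (shape, whether the dtype is a typed-key dtype)."""
    is_typed = jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
    return key.shape, bool(is_typed)


print(describe(raw_key))    # ((2,), False)
print(describe(typed_key))  # ((), True)
```

An annotation that only names one of these representations will reject the other, at runtime or in the static checker.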
Thank you! It looks like there's already an issue for this in the JAX repo:
google/jax#12706
and indeed, the recommendation for now is to use a type union.
For anyone else coming across this: use `jaxtyping.PRNGKeyArray` and you should be good to go.
Related Issues (20)
- Feature request: Remove entry "modules" in function "install_import_hook()"
- Venv `__pycache__` directories filling up
- Failed to compile with Union
- Type annotations must now include an explicit array type
- Simple script error
- A bug with typechecking fields of dataclasses with default values/factories
- Support for jax.dtypes.prng_key, to denote jax.Arrays of PRNG keys, as in [JEP 9263](https://jax.rtfd.io/en/latest/jep/9263-typed-keys.html).
- test_import_hook_transitive is flaky!
- Order of symbolic expression evaluation
- Module-level import hook
- Support runtime type-checking of generic functions
- Random key does not typecheck Key[Scalar, ""]
- Weird `KeyError: '0'` when using None typechecker for `install_import_hook`
- ImportError: cannot import name 'Array' from 'jaxtyping'
- Symbolic expressions in argument annotations
- Old-style decoration fails to raise on dataclasses since 0.2.24
- will runtime type checking go beyond function parameters and return type?
- [DOC] Need better documentation about `from __future__ import annotations`
- How can I inspect the jaxtyping bindings?
- IPython `inspect.getsource()` failure due to incorrect co_firstlineno