Ran into this as well with PyTorch. For me the solution, as described at https://docs.kidger.site/jaxtyping/api/array/#array, was to use `torch.Tensor` in place of `jaxtyping.Array`, like so:
```python
from torch import Tensor
from jaxtyping import Float32

def f(x: Float32[Tensor, "dim1 dim2"]) -> Float32[Tensor, "dim1 dim2"]:
    return x
```
from jaxtyping.
You need to have JAX installed as well: jaxtyping has JAX only as an optional dependency, so that it can also be used with PyTorch etc.
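A minimal sketch of how calling code might check for the optional JAX dependency before reaching for `jaxtyping.Array`. It uses only the standard library (`importlib.util.find_spec`). The helper `has_module` and the flag `JAX_AVAILABLE` are hypothetical names, not part of jaxtyping. Note that finding a module is not the same as importing it successfully.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if top-level module `name` is installed, without importing it."""
    return importlib.util.find_spec(name) is not None


# Decide once, at import time, whether jaxtyping.Array is an option here.
JAX_AVAILABLE = has_module("jax")

if JAX_AVAILABLE:
    print("JAX found: jaxtyping.Array should be usable")
else:
    print("JAX missing: annotate with e.g. numpy.ndarray or torch.Tensor instead")
```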
@MilesCranmer thanks. @patrick-kidger, was the JAX requirement relaxed? I don't see it in `pyproject.toml` anymore. If so, I'll modify my code to use the syntax suggested by Miles.
Sorry, missed this question. Yes, jaxtyping no longer depends on JAX. The name is now for historical reasons only! The syntax Miles is using is correct.
When authoring ML-runtime-agnostic tooling, such as a dataset, what is the correct array type to use? I cannot assume that any of torch, jax or tensorflow is installed. I currently assume at least numpy and do the following, but it might not work for other use cases:
```python
from typing import Union, TYPE_CHECKING
from jaxtyping import Float, Bool

if TYPE_CHECKING:
    from torch import Tensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    # TODO: tensorflow
    Array = Union[Tensor, ndarray, JaxArray]
else:
    from numpy import ndarray as Array
```
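To make the trade-off in this approach concrete: `typing.TYPE_CHECKING` is `False` at runtime, so static type checkers see the full `Union` while the interpreter only ever binds the `else` branch. The sketch below substitutes the built-in `list` and `tuple` for the array libraries so that it runs anywhere. The names `ArrayLike` and `is_array` are illustrative only.

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only static checkers (mypy, pyright, ...) evaluate this branch.
    ArrayLike = Union[list, tuple]
else:
    # At runtime TYPE_CHECKING is False, so this is the binding that exists.
    ArrayLike = list


def is_array(x: object) -> bool:
    # A runtime isinstance check only knows about the runtime binding.
    return isinstance(x, ArrayLike)


print(is_array([1, 2]), is_array((1, 2)))  # True False
```

A tuple passes the static check but fails at runtime. The same thing would happen to a `torch.Tensor` under the numpy-only fallback.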
Probably something like this:
```python
from typing import Union, TYPE_CHECKING

if TYPE_CHECKING:
    # Static checkers see every backend, whether or not it is installed.
    from torch import Tensor as TorchTensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    from tensorflow import Tensor as TfTensor
    Array = Union[TorchTensor, ndarray, JaxArray, TfTensor]
else:
    # At runtime, collect only the array types that can actually be imported.
    arrays = []
    try:
        from torch import Tensor as TorchTensor
    except Exception:
        pass
    else:
        arrays.append(TorchTensor)
    try:
        from numpy import ndarray
    except Exception:
        pass
    else:
        arrays.append(ndarray)
    try:
        from jaxtyping import Array as JaxArray
    except Exception:
        pass
    else:
        arrays.append(JaxArray)
    try:
        from tensorflow import Tensor as TfTensor
    except Exception:
        pass
    else:
        arrays.append(TfTensor)
    # Note: raises TypeError if none of the backends is installed.
    Array = Union[tuple(arrays)]
```
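The repeated try/except blocks above could be collapsed into a loop over `(module, attribute)` pairs. This also leaves the choice of backends to the downstream project instead of hardcoding it. This is a sketch: the helper name `available_types` is hypothetical. The standard-library types in the example list stand in for `torch.Tensor`, `numpy.ndarray` and friends, so the example runs without any ML framework installed.

```python
import importlib
from typing import Iterable, Tuple, Union


def available_types(specs: Iterable[Tuple[str, str]]) -> Tuple[type, ...]:
    """Return the `module.attr` types that can be imported, skipping the rest."""
    found = []
    for module_name, attr in specs:
        try:
            module = importlib.import_module(module_name)
            found.append(getattr(module, attr))
        except Exception:
            # Broad on purpose: a backend may fail to import with errors other
            # than ImportError, and getattr raises AttributeError if it is absent.
            continue
    return tuple(found)


# Stand-ins for e.g. ("torch", "Tensor"), ("numpy", "ndarray"), ...
types = available_types([
    ("collections", "OrderedDict"),
    ("decimal", "Decimal"),
    ("not_a_real_backend", "Tensor"),  # not installed: silently skipped
])

# Union[()] raises TypeError, so fall back to `object` if nothing was found.
Array = Union[types] if types else object
```

`isinstance(x, types)` then works as a runtime check against whichever backends are present. In a real project the list might be `[("torch", "Tensor"), ("numpy", "ndarray"), ("jax", "Array"), ("tensorflow", "Tensor")]`, chosen by the project rather than by jaxtyping.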
Neat! I'd go for `except (ModuleNotFoundError, ImportError):` 😉 And it doesn't exactly roll off the tongue. Any chance this could be added to jaxtyping?
Actually, the more general `Exception` is deliberate. There are cases where try-importing a module can raise other errors too, cf. https://github.com/google/jaxtyping/blob/7a84b27da9e57c425ce4e6333121c3cdf2e07302/jaxtyping/_array_types.py#L33-L39
As for adding the above to jaxtyping: jaxtyping tries to be essentially backend-agnostic. In particular, I don't think I'd want to hardcode that it looks specifically for torch+numpy+tensorflow+jax and nothing else. As such, I think something like this is out of scope for jaxtyping, I'm afraid.
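A small, self-contained illustration of why `except Exception` is broader than `except (ModuleNotFoundError, ImportError)`. First, `ModuleNotFoundError` is already a subclass of `ImportError`, so listing both adds nothing. Second, a module whose top level raises something else slips past `except ImportError` but is caught by `except Exception`. The module name `flaky_backend`, the helper `try_import`, and the `RuntimeError` are made up for the demo. The module is written to a temporary directory.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def try_import(name, catch):
    """Import `name`; return None if the import raises one of `catch`."""
    try:
        return importlib.import_module(name)
    except catch:
        return None


# ModuleNotFoundError is already a subclass of ImportError.
assert issubclass(ModuleNotFoundError, ImportError)

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical backend that is installed but fails while being imported.
    Path(tmp, "flaky_backend.py").write_text(
        'raise RuntimeError("found, but failed to initialise")\n'
    )
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        # Catching Exception swallows the failure, so we can fall back.
        broad_result = try_import("flaky_backend", Exception)
        # Catching only ImportError lets the RuntimeError propagate.
        try:
            try_import("flaky_backend", ImportError)
            narrow_escaped = False
        except RuntimeError:
            narrow_escaped = True
    finally:
        sys.path.remove(tmp)
        sys.modules.pop("flaky_backend", None)

print(broad_result, narrow_escaped)  # None True
```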