Comments (2)
I think the __slots__ is unnecessary on a class that can never be instantiated. :)
But otherwise @mivanit's solution is a good one! I usually recommend writing something like:
import torch
from jaxtyping import Float

class f:
    def __class_getitem__(cls, item):
        return Float[torch.Tensor, item]
...which is essentially the same thing as @mivanit's solution, but without the bells and whistles such as the docstring, blocking instantiation, etc.
Note that this won't work with static type checkers, though. (Support for static type checking is the reason that jaxtyping works the way that it does.) If you do want to support them too, then typically the best you can do is something like:
from torch import Tensor as loat
from jaxtyping import Float as F
F[loat, "b w h"]
...which is obviously a huge hack to work around the heavy limitations of static type checkers.
from jaxtyping.
Here's what I use; no idea if it's best practice. It works with beartype for me, although I've only done fairly limited tests:
import typing

import jaxtyping
import numpy as np
import torch


def jaxtype_factory(
    name: str,
    jax_dtype: type,
    array_type: type = jaxtyping.Float,
) -> type:
    # Note: despite the parameter names, `jax_dtype` is the array type
    # (e.g. torch.Tensor) and `array_type` is the jaxtyping dtype (e.g. Float).
    class _BaseArray:
        """jaxtyping shorthand
        jax_dtype = {jax_dtype}
        array_type = {array_type}
        """

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            # The shorthand is only meant for annotations, never for instances.
            raise TypeError(f"Type {name} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {name}")

        # typing._tp_cache is a private helper that memoizes subscriptions.
        @typing._tp_cache
        def __class_getitem__(cls, params):
            if isinstance(params, str):
                # e.g. F_Tensor["b c"] -> jaxtyping.Float[torch.Tensor, "b c"]
                return array_type[jax_dtype, params]
            else:
                raise TypeError(f"unexpected type for params:\n{type(params) = }\n{params = }")

    _BaseArray.__name__ = name
    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        jax_dtype=repr(jax_dtype),
        array_type=repr(array_type),
    )
    return _BaseArray


# this makes linters happy
class F_Tensor(torch.Tensor):
    @typing._tp_cache
    def __class_getitem__(cls, params):
        raise NotImplementedError()


# at runtime, replace the linter-facing stub with the real shorthand
F_Tensor = jaxtype_factory("F_Tensor", torch.Tensor, jaxtyping.Float)
Related Issues (20)
- Feature request: Remove entry "modules" in function "install_import_hook()"
- Venv `__pycache__` directories filling up
- Failed to compile with Union
- Type annotations must now include an explicit array type
- Simple script error
- A bug with typechecking fields of dataclasses with default values/factories
- Support for jax.dtypes.prng_key, to denote jax.Arrays of PRNG keys, as in [JEP 9263](https://jax.rtfd.io/en/latest/jep/9263-typed-keys.html).
- test_import_hook_transitive is flaky!
- Order of symbolic expression evaluation
- Module-level import hook
- Support runtime type-checking of generic functions
- Random key does not typecheck Key[Scalar, ""]
- Weird `KeyError: '0'` when using None typechecker for `install_import_hook`
- ImportError: cannot import name 'Array' from 'jaxtyping'
- Symbolic expressions in argument annotations
- Old-style decoration fails to raise on dataclasses since 0.2.24
- will runtime type checking go beyond function parameters and return type?
- [DOC] Need better documentation about `from __future__ import annotations`
- How can I inspect the jaxtyping bindings?
- IPython `inspect.getsource()` failure due to incorrect co_firstlineno