With a bit of trickery, this is possible!

```python
from typing import TypeVar, Union, TYPE_CHECKING
from jaxtyping import Array, Float32, Shaped

if TYPE_CHECKING:  # needed for static type checking compatibility
    class _Unused:
        pass

    T = TypeVar("T")
    Batched = Union[T, _Unused]
else:
    class Batched:
        def __class_getitem__(cls, item):
            return Shaped[item, "N"]
```

This doesn't support `Batched[PyTree[...]]` though; you should write that as `PyTree[Batched[...]]`. (With a bit of hackery you could detect `Batched[PyTree[...]]` in `__class_getitem__` and convert it, if you really wanted to.)
Hi Patrick, thanks a lot for your quick reply! This looks great. I was playing around with this since I also have some `PyTree[leaf]` type aliases, so I modified the `__class_getitem__` dunder to extract the PyTree leaf type when possible:

```python
from jaxtyping import PyTree, Shaped

class Batched:
    def __class_getitem__(cls, item):
        # If `item` is a PyTree[...] alias, batch its leaf type instead
        if hasattr(item, 'leaftype'):
            return PyTree[Shaped[item.leaftype, 'N']]
        return Shaped[item, 'N']
```