Comments (6)
I think both of these features should already exist!
I think the min notation should work as-is: see "symbolic expression" here. Broadcasting can be notated with a #. See the above link, and in particular the example for broadcasting over multiple dimensions, i.e. def add(x: Float[Array, "*#foo"], y: Float[Array, "*#foo"]) -> Float[Array, "*#foo"], which I think is what you're aiming to do.
from jaxtyping.
Very cool! When reading the docs I had thought that expressions only included arithmetic operators and did not extend to functions. For the batch behaviour I am looking for, unfortunately, I don't think the existing broadcasting operations will handle what I need. To be explicit: in that library, a matrix can also have batch dimensions, indicating multiple sets of arrays that are being worked on. So, when we form C = A * B we may have
dim(A) = (3, 4, 5, 6)
dim(B) = (6, 7)
and we expect dim(C) = (3, 4, 5, 7).
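For concreteness, NumPy's matmul already implements exactly this batch-broadcasting rule, so the expected result shape can be checked directly (a small illustration of the shapes above, not jaxtyping code):

```python
import numpy as np

# Shapes from the discussion: A carries two leading batch dimensions,
# B carries none. np.matmul broadcasts the batch dimensions and
# contracts the shared inner dimension (6).
A = np.zeros((3, 4, 5, 6))
B = np.zeros((6, 7))
C = np.matmul(A, B)
print(C.shape)  # (3, 4, 5, 7)
```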
Basically, what I think I need is either a coalesce function or an 'optional' operator. The idea of the optional operator would be to use the value of the variable if it is defined and to ignore it otherwise. So, if I define '?' as the optional operator, I could write:
def _matmul(
    self: Float[LinearOperator, "*batchA M N"],
    rhs: Float[torch.Tensor, "*batchB N C"],
) -> Float[torch.Tensor, "?batchA ?batchB M C"]:
It doesn't look like there is a direct way to write a coalesce function, but I am tempted to change the eval call in _check_dims to pass in obj_shape. Then a custom function like coalesce could be written to decide whether the dimension check passes. But this may not play well with variables that hold multiple dimensions - I guess I will need to experiment. Thanks again for pointing out the eval feature!
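The coalesce idea could be sketched like this (a toy illustration of the intended semantics, not jaxtyping's implementation; the function name and behaviour are my assumptions):

```python
# Hypothetical "coalesce": return the first batch shape that is defined,
# mirroring the intent of "?batchA ?batchB" in the annotation above.
def coalesce(*batch_shapes):
    for shape in batch_shapes:
        if shape is not None:
            return shape
    return ()  # no batch dimensions bound at all

# e.g. batchA bound from `self`, batchB absent:
coalesce((3, 4), None)  # -> (3, 4)
```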
I think what you're looking for is a return annotation of Float[torch.Tensor, "*batchA *batchB M C"].
Note in particular that * may bind against zero dimensions, so batchA and batchB will always be defined.
So I think this already exists in terms of being able to represent it in the "shape mini-language". There is one wrinkle: in practice, multiple * are not supported in a single annotation, as this needs a more complicated checking algorithm. (In general it would need to keep track of multiple possibilities, rather than just a single one.)
If you want, I'd be happy to accept a PR on this for at least the limited case you're interested in: supporting multiple * provided the dimensions are already defined. (In this case, they're already defined from the input.)
The min(A,B) feature would be cool indeed, but it does not work at the moment, because it is prevented by the common-mistakes check for ',' usage:

ValueError: Dimensions should be separated with spaces, not commas

Maybe that check could be disabled if the expression starts with min or max? Are there any other useful operations with a comma?
I quite like the idea of an optional operator, A ?B C. It seems very intuitive and I expect it would come in handy occasionally. But I can also see that it would substantially complicate the shape checking. Probably not worth the effort.
Hmm, that's a good point about the comma check getting in the way. A simple fix might be some heuristic like "is this comma surrounded by a matching number of brackets?", or something similar.
WDYT? I'd be happy to take a PR on this.
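That heuristic could be sketched roughly like this (a hypothetical illustration, not jaxtyping's actual check): a comma is acceptable when it sits inside at least one pair of brackets, as in min(A,B), but not at the top level, as in the "A,B" mistake the error guards against.

```python
def comma_inside_brackets(expr):
    """Return False if any comma appears outside all brackets."""
    depth = 0
    for ch in expr:
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == "," and depth == 0:
            return False  # top-level comma: likely the common mistake
    return True

comma_inside_brackets("min(A,B)")  # True: comma is bracketed
comma_inside_brackets("A,B")       # False: top-level comma
```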
from jaxtyping.
Fixed in v0.2.15!