I think both of these features should already exist!
I think the `min` notation should work as-is: see "symbolic expression" here.
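To make the symbolic-expression idea concrete, here is a minimal Python sketch of how a dimension such as `min(A,B)` might be resolved. It assumes the expression is evaluated with Python's `eval` against sizes already bound by other arguments; the helper name `eval_dim` and the whitelist of builtins are hypothetical, not jaxtyping's actual API.

```python
# Hypothetical sketch: evaluate a symbolic dimension such as "min(A,B)"
# against sizes already bound by earlier arguments. This only mimics the
# idea of symbolic expressions; it is not jaxtyping's implementation.

def eval_dim(expr: str, bindings: dict[str, int]) -> int:
    """Evaluate a dimension expression using only the bound sizes plus
    a small whitelist of builtins (min, max, abs)."""
    allowed = {"min": min, "max": max, "abs": abs}
    # Empty __builtins__ so the expression can only see the names passed in.
    return eval(expr, {"__builtins__": {}}, {**allowed, **bindings})


bindings = {"A": 3, "B": 5}
print(eval_dim("min(A,B)", bindings))  # 3
print(eval_dim("A*B-1", bindings))     # 14
```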
Broadcasting can be notated with a `#`. See the above link, and in particular the example for broadcasting over multiple dimensions, i.e. `def add(x: Float[Array, "*#foo"], y: Float[Array, "*#foo"]) -> Float[Array, "*#foo"]`, which I think is what you're aiming to do.
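For readers unfamiliar with what `#` permits, the sketch below spells out NumPy-style shape broadcasting in plain Python. It assumes `*#foo` follows the usual rules (compare axes right to left, let size-1 axes stretch); `broadcast_shapes` here is a stand-alone helper written for illustration (NumPy ships a similar `numpy.broadcast_shapes`), not part of jaxtyping.

```python
# Hypothetical sketch of NumPy-style broadcasting, roughly what a "*#foo"
# annotation permits: axes are compared right-to-left, and a size-1 axis
# may stretch to match the other array.

def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for i in range(1, ndim + 1):  # walk from the last axis backwards
        sizes = {s[-i] for s in shapes if len(s) >= i}
        non_one = sizes - {1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible sizes {sorted(sizes)} at axis -{i}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


print(broadcast_shapes((3, 1, 5), (4, 5)))  # (3, 4, 5)
```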
Very cool! When reading the docs I had thought that expressions only included arithmetic operators and did not extend to functions. For the batch stuff I am looking for, unfortunately, I don't think the existing broadcasting operations will handle what I need. To be explicit: in that library, a matrix can also have batch dimensions, indicating multiple sets of arrays that are being worked on. So, when we form C = A*B we may have
dim(A) = (3, 4, 5, 6)
dim(B) = (6, 7)
and we expect dim(C) = (3, 4, 5, 7).
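The shape rule in this example can be written down as a small function. This is a hedged sketch that reads `A*B` as a batched matrix product (trailing axes `(M, N) @ (N, C) -> (M, C)`, leading batch axes broadcast) and works on plain shape tuples rather than the library's own operator types; `batched_matmul_shape` is a hypothetical name.

```python
# Hypothetical sketch of the batched matrix-product shape rule:
# trailing axes multiply as (M, N) @ (N, C) -> (M, C), and any leading
# batch axes broadcast. Assumes both operands have at least two axes.

def batched_matmul_shape(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    *batch_a, m, n = a
    *batch_b, n_b, c = b
    if n != n_b:
        raise ValueError(f"inner dimensions differ: {n} vs {n_b}")
    # Left-pad the shorter batch with 1s, then broadcast axis by axis.
    ndim = max(len(batch_a), len(batch_b))
    pa = [1] * (ndim - len(batch_a)) + batch_a
    pb = [1] * (ndim - len(batch_b)) + batch_b
    batch = []
    for x, y in zip(pa, pb):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch sizes {x} and {y} do not broadcast")
        batch.append(y if x == 1 else x)  # a size-1 axis takes the other size
    return (*batch, m, c)


print(batched_matmul_shape((3, 4, 5, 6), (6, 7)))  # (3, 4, 5, 7)
```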
Basically, what I think I need is either a coalesce function or an optional operator. The idea of the 'optional' operator would be to check the value of the variable if it is defined and ignore it otherwise. So, if I define `?` to be the 'optional' operator, I could write:
```python
# Proposed syntax: "?" marks an optional variable that is ignored if unbound.
def _matmul(
    self: Float[LinearOperator, "*batchA M N"],
    rhs: Float[torch.Tensor, "*batchB N C"],
) -> Float[torch.Tensor, "?batchA ?batchB M C"]:
    ...
```
It doesn't look like there is a direct way to write a coalesce function, but I am tempted to change the `eval` call in `_check_dims` to pass in `obj_shape`. Then a custom function like `coalesce` could be written to decide whether the dimension check passes. But this may not play well with variables that hold multiple dimensions; I may need to experiment. Thanks again for pointing out the `eval` feature!
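A rough sketch of the proposed `?` semantics, under the assumption that dimension bindings are stored as tuples (a single dimension being a 1-tuple): `?name` contributes its bound dimensions when `name` exists and nothing otherwise. Neither the `?` syntax nor `expand_spec` exists in jaxtyping; they only illustrate the proposal.

```python
# Hypothetical sketch of the proposed "?" (optional) operator: when the
# expected shape is built, "?name" expands to the dimensions bound to
# `name` if it was bound, and contributes nothing otherwise.

def expand_spec(spec: str, bindings: dict[str, tuple[int, ...]]) -> tuple[int, ...]:
    dims: list[int] = []
    for token in spec.split():
        if token.startswith("?"):
            dims.extend(bindings.get(token[1:], ()))  # skip if unbound
        else:
            dims.extend(bindings[token])  # plain names must be bound
    return tuple(dims)


bindings = {"batchA": (3, 4), "M": (5,), "C": (7,)}  # batchB never bound
print(expand_spec("?batchA ?batchB M C", bindings))  # (3, 4, 5, 7)
```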
I think what you're looking for is a return annotation of `Float[torch.Tensor, "*batchA *batchB M C"]`.
Note in particular that `*` may bind against zero dimensions, so `batchA` and `batchB` will always be defined.
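The following minimal sketch shows why a single `*name` is always bound: when a shape is matched against a spec containing at most one variadic token, that token absorbs whatever axes are left over, possibly none. `match_single_variadic` is a hypothetical helper, and for brevity it does not check that a repeated name such as `N` gets the same size in both arguments.

```python
# Hypothetical helper: match a shape against a spec with at most one "*name".
def match_single_variadic(spec: str, shape: tuple[int, ...]) -> dict[str, tuple[int, ...]]:
    tokens = spec.split()
    stars = [i for i, t in enumerate(tokens) if t.startswith("*")]
    if len(stars) > 1:
        raise NotImplementedError("this sketch handles at most one '*' per spec")
    if not stars:
        if len(shape) != len(tokens):
            raise ValueError(f"expected {len(tokens)} dims, got {len(shape)}")
        return {t: (d,) for t, d in zip(tokens, shape)}
    i = stars[0]
    before, after = tokens[:i], tokens[i + 1:]
    n_var = len(shape) - len(before) - len(after)  # may be 0
    if n_var < 0:
        raise ValueError(f"shape {shape} has too few dims for {spec!r}")
    bindings = {t: (d,) for t, d in zip(before, shape[:i])}
    bindings[tokens[i][1:]] = tuple(shape[i:i + n_var])  # possibly ()
    bindings.update({t: (d,) for t, d in zip(after, shape[i + n_var:])})
    return bindings


print(match_single_variadic("*batchA M N", (3, 4, 5, 6)))
# {'batchA': (3, 4), 'M': (5,), 'N': (6,)}
print(match_single_variadic("*batchB N C", (6, 7)))
# {'batchB': (), 'N': (6,), 'C': (7,)}
```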
So I think this can already be represented in the "shape mini-language". There is one wrinkle: in practice, multiple `*` are not supported in a single annotation, as this needs a more complicated checking algorithm. (In general it would need to keep track of multiple possibilities, rather than just a single one.)
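To illustrate the "multiple possibilities" problem, this sketch enumerates every way the shape `(3, 4, 5, 7)` could be split under `*batchA *batchB M C` when neither variadic is bound yet. `candidate_splits` is a hypothetical helper; a real checker would additionally have to carry each candidate forward through the other arguments.

```python
# Hypothetical sketch: with two unbound "*" tokens, one shape admits several
# splits, so a checker would have to track every candidate binding.

def candidate_splits(shape: tuple[int, ...], n_fixed_after: int):
    """Yield (batchA, batchB) pairs for a spec '*batchA *batchB' followed
    by n_fixed_after single dimensions."""
    n_var = len(shape) - n_fixed_after
    for k in range(n_var + 1):
        yield shape[:k], shape[k:n_var]


for batch_a, batch_b in candidate_splits((3, 4, 5, 7), n_fixed_after=2):
    print(batch_a, batch_b)
# () (3, 4)
# (3,) (4,)
# (3, 4) ()
```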
If you want, I'd be happy to accept a PR on this for at least the limited case you're interested in: supporting multiple `*` provided the dimensions are already defined. (In this case, they're already defined from the input.)
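In the limited case, every `*` in the return annotation is already bound by the inputs, so checking reduces to expanding the spec and comparing shapes. A minimal sketch of that idea, with a hypothetical `check_bound_spec` helper and bindings stored as tuples:

```python
# Hypothetical sketch: several "*" tokens are allowed as long as every one
# was already bound by the inputs, so checking is expand-and-compare.

def check_bound_spec(spec: str, shape: tuple[int, ...],
                     bindings: dict[str, tuple[int, ...]]) -> None:
    expected: list[int] = []
    for token in spec.split():
        name = token.lstrip("*")
        if name not in bindings:
            raise NotImplementedError(f"{token!r} is not bound yet")
        expected.extend(bindings[name])
    if tuple(expected) != tuple(shape):
        raise TypeError(f"expected shape {tuple(expected)}, got {tuple(shape)}")


# Bindings as they would come out of matching the two inputs.
bindings = {"batchA": (3, 4), "batchB": (), "M": (5,), "N": (6,), "C": (7,)}
check_bound_spec("*batchA *batchB M C", (3, 4, 5, 7), bindings)  # no error
```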
The `min(A,B)` feature would be cool indeed, but it does not work at the moment, because it is blocked by the common-mistakes check for `,` usage:
ValueError: Dimensions should be separated with spaces, not commas
Maybe that check could be disabled if the expression starts with `min` or `max`? Are there any other useful operations that take a comma?
I quite like the idea of an optional operator, as in `A ?B C`. It seems very intuitive and I expect it to come in handy occasionally. But I can also see that it would substantially complicate the shape checking, so it is probably not worth the effort.
Hmm, that's a good point about the comma check getting in the way. A simple fix might be a heuristic like "is this comma enclosed in a matching pair of brackets?", or something similar.
WDYT? I'd be happy to take a PR on this.
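A minimal sketch of the bracket heuristic, assuming only round brackets matter: a comma is flagged as a separator mistake only when it sits outside every pair of parentheses, so `min(A,B)` passes while `batch, M, N` is still rejected. The function names are hypothetical, and this is not necessarily the check jaxtyping shipped.

```python
# Hypothetical sketch: only commas at bracket depth 0 count as the
# "separated with commas" mistake; commas inside a call are allowed.

def has_top_level_comma(spec: str) -> bool:
    depth = 0
    for ch in spec:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == "," and depth == 0:
            return True
    return False


def check_commas(spec: str) -> None:
    if has_top_level_comma(spec):
        raise ValueError("Dimensions should be separated with spaces, not commas")


check_commas("batch min(A,B) N")  # accepted: the comma is inside brackets
try:
    check_commas("batch, M, N")  # rejected: top-level commas
except ValueError as e:
    print(e)
```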
Fixed in v0.2.15!