Thank you both for your thoughts. :)
Personally I agree with both of you; I'd much rather have Float["foo bar"] or Array[Float, "foo bar"] as well. However, the Float[Array, "foo bar"] syntax is the best approach I've been able to find that has any compatibility with static type checkers.
This is necessary since jax.typing will happen at some point. The core JAX developers are happy to work with jaxtyping to ensure we don't bifurcate into incompatible approaches, but static type checking is their goal, so we need to work with them in turn. (I'm all ears to any suggestions that will pass both mypy and pytype.)
(I'd quite like to support static type checking as well; it's a pretty useful thing, when it works.)
@thomaspinder - regarding point 4, do you mean that you would rather see Float32 over f32? (I don't have a strong preference myself, so I'd be happy to change this.)
@ayaka14732 - indeed, specifying the NumPy/JAX/etc. backend is also desirable. It's a lesser priority, but supporting this is also a goal of the proposed new syntax. (Note that it's quite difficult to have optional arguments that static type checkers understand; off the top of my head the only things which support this are typing.{Literal, Annotated, Tuple} and variadic generics. An Annotated solution is probably the most likely to work.)
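A minimal sketch of how an Annotated-based approach can satisfy both worlds, assuming Python 3.9+. Under typing.TYPE_CHECKING, a static checker is told that Float is simply typing.Annotated, so Float[Array, "batch channels"] reads to mypy/pytype as Annotated[Array, "batch channels"], i.e. plain Array with ignored metadata. At runtime a stand-in class builds an Annotated type that carries the dimension names. The runtime Float class, the placeholder Array class, the describe helper and the metadata layout are all hypothetical illustrations, not jaxtyping's actual implementation.

```python
from typing import TYPE_CHECKING, Annotated, get_args

if TYPE_CHECKING:
    # Static checkers see Float[X, "dims"] as Annotated[X, "dims"], i.e. just X.
    from typing import Annotated as Float
else:
    class Float:
        """Hypothetical runtime stand-in for Float[ArrayType, "dims"]."""

        def __class_getitem__(cls, params):
            array_type, dims = params
            # Record the dtype category and the dimension names as metadata,
            # so a runtime checker can read them back later.
            return Annotated[array_type, ("float", tuple(dims.split()))]


class Array:
    """Placeholder for jax.Array, to keep this sketch self-contained."""


def describe(hint):
    """Split an Annotated hint back into (array_type, metadata)."""
    array_type, metadata = get_args(hint)
    return array_type, metadata


def f(x: Float[Array, "batch channels"]) -> None:
    ...
```

Here describe(f.__annotations__["x"]) gives back Array together with ("float", ("batch", "channels")), while a static checker should only ever see the parameter type Array.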
Interesting. Thanks for the clarification. If it enables static type checking, then the Float[Array, "foo bar"] syntax seems like a necessary sacrifice.
@patrick-kidger Yes - I, personally, would prefer Float32 over f32.
jaxtyping and jax.typing
- I hope that static type checking will not be a main goal if it makes things more complicated. I also think that compatibility is not necessary, since it would often increase the complexity. And having two separate projects would not be a bad thing, because we are free to choose between the two.
- As a consequence, the project name would have to be changed. Moreover, considering that we will be supporting not only jnp.array but also NumPy/PyTorch/TensorFlow, the project name should not contain the word JAX.
- I suggest tensortyper as the new project name. It is easy to pronounce, since both words start with t and end with r.
Syntax changes
- I think longer names like Float are better than shorter ones like f, because 'explicit is better than implicit'. And if we really want to use shorter names, we can always do from jaxtyping import Float as f.
- If we are going to use the syntax Float[Array, "foo bar"], I hope that it could be FloatArray[Array, "foo bar"]. This is because the former sounds like a certain type of float instead of an array, just as List[int] is a certain type of list.
- I prefer IntSign/SignInt to IntS/SInt, because 'explicit is better than implicit'. Besides, I think it should be IntSigned/SignedInt, because 'signed' is the adjective.
Would Float[jnp.ndarray, "batch length channels"] be a bit confusing? I am thinking that something like jnp.ndarray[Float, "batch length channels"] or FloatArray[jnp.ndarray, "batch length channels"] would look more intuitive.
> Whilst by no means a strong opinion, I prefer Float["foo"] over Float[Array, "foo"]. Everything in GPJax where we use JaxTyping is an Array, so the additional line is somewhat superfluous.
A common use case I can think of: when the dataset is too large to fit into TPU memory, we usually load the entire dataset into CPU memory as a NumPy array first, then slice it into small batches and convert them to JAX arrays on the TPU. In this case it would be necessary to distinguish between np.array and jnp.array. (See google/jax#8933 (comment))
Besides, I am also wondering whether it is possible to make np.array/jnp.array an optional second argument, so that we can just omit it if we are indifferent to the type.
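At runtime alone, an optional array-type argument is easy to support: __class_getitem__ receives either a bare string or a tuple, and can branch on that. The sketch below assumes a runtime-only design; the Float class, the ShapeSpec record, the backend_ok helper and the use of object as the "any array" default are all hypothetical. As noted earlier in the thread, the hard part is that static type checkers do not generally understand optional subscript arguments like this.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapeSpec:
    """Hypothetical record of what a hint such as Float[np.ndarray, "b t"] asks for."""

    array_type: type
    dims: tuple


class Float:
    """Hypothetical runtime-only Float that makes the array type optional."""

    def __class_getitem__(cls, params):
        if isinstance(params, str):
            # Float["b t"]: no backend given, so accept any array type.
            return ShapeSpec(object, tuple(params.split()))
        # Float[np.ndarray, "b t"] or Float[jax.Array, "b t"]: backend given.
        array_type, dims = params
        return ShapeSpec(array_type, tuple(dims.split()))


def backend_ok(spec, value):
    """Check only the backend part of the spec; shape checking is omitted here."""
    return isinstance(value, spec.array_type)
```

With list and tuple standing in for np.ndarray and jax.Array, backend_ok(Float[list, "n"], (1, 2)) is False, whereas backend_ok(Float["n"], (1, 2)) is True because the backend was left unspecified.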
CC @leycec for curiosity, as someone who evidently has a lot of strong opinions on type-checking. (You might also like the trick through which Float[Array, "foo"] is made to work with static type checkers.)
Thanks for the heads up @patrick-kidger. I'm OK with the look of both proposed changes, though I do particularly like the explicitness of the former.
To answer your specific questions in order:
- Personally, I'm fine with this.
- I like the verbosity of IntSign and IntUnsign. It is a few more characters, but it avoids confusion; e.g., IntS could be misconstrued as multiple Ints.
- By the same reasoning, I like Float.
- As above.
- Whilst by no means a strong opinion, I prefer Float["foo"] over Float[Array, "foo"]. Everything in GPJax where we use JaxTyping is an Array, so the additional line is somewhat superfluous.
As an aside, it would be cool to see more mathematically inspired typing systems, e.g., ColumnVector['N'] would type an N x 1 vector. Extensions would be RowVector, Scalar, Matrix and Tensor. I have no idea how feasible this is or how you feel about it, but I'm happy to open a separate issue to discuss it further if it is of interest.
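One way such names could be layered on top of the existing dimension-string syntax is as thin aliases that expand to a shape string. This is a hypothetical sketch only (none of these names exist in jaxtyping, and _ShapeAlias is invented for illustration); Tensor would presumably just be the dimension string itself, so it is left out.

```python
class _ShapeAlias:
    """Hypothetical helper: Alias["N"] expands to a dimension string via a template."""

    def __init__(self, template):
        self.template = template

    def __getitem__(self, dims):
        # "M N" -> ["M", "N"], then fill the template positionally.
        return self.template.format(*dims.split())


ColumnVector = _ShapeAlias("{0} 1")  # N x 1
RowVector = _ShapeAlias("1 {0}")     # 1 x N
Matrix = _ShapeAlias("{0} {1}")      # M x N
Scalar = ""                          # no dimensions at all
```

The resulting strings could then be dropped into the existing syntax, e.g. Float[Array, ColumnVector["N"]], which at runtime passes exactly the same string as writing Float[Array, "N 1"] directly.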
Alternate project names? Having both jax.typing and jaxtyping is a bit confusing. (But perhaps not too bad, if jaxtyping is a strict superset.)
Taking inspiration from the official typing-extensions Python package, maybe rename jaxtyping to jaxtyping-extensions?
Thanks for this RFC @patrick-kidger! Regarding the naming for types, I'd say "adherence to the NumPy API" should be one axis on which to evaluate syntax changes. JAX's own potential stems from "it's just NumPy", and this would help onboard new users as quickly as possible. By that line of reasoning:
- Verbosity: Types should be more verbose; not just NumPy, but also Python's own "explicit is better than implicit" plays into Float32 feeling more "pythonic" than f32.
- Specific Cases: Extending this reasoning, I'd prefer to see syntax such as Int and UInt over IntSign and IntUnsign, since they're closer to the established np.int and np.uint.
- Imports: A question I've been thinking about is: should types be direct imports (e.g. from typing import Tuple) or module-level imports (e.g. import numpy as np, then use np.float32)? I could easily see either Float or jxt.float style syntax working (where jxt would be import jaxtyping as jxt). After discussion, I personally would lean towards the former (direct imports), since these classes exist to aid type-checking.
- Open Question: How challenging would it be under Python to have a type such as Array[Float, "b t d"] as opposed to Float[Array, "b t d"] as proposed here? I'm curious whether this is a design choice or a reflection of more fundamental limitations in the language, since the former would be closer to how other languages (such as Julia's Array{Float32}) present their types.
I'll add this: I'm enjoying using this library; the multi-argument runtime type-checking over tensors is a brilliant idea, and makes me go "how was I not checking this before?". The fact it just works with all the other JAX transforms (jit, pmap, etc.) is equally neat. The syntax changes shouldn't take away from the underlying achievement that, even as the library is right now, it works well! As this RFC is elaborated on, I'm curious to see the "default, recommended" way to use jaxtyping fleshed out going forward.
Thanks all for your feedback. The new version has now been merged: #16.
(In particular I appreciate all the positive feedback to the effect of "jaxtyping is super useful, thanks" -- this warms a library author's heart.)
To respond to the last round of comments raised:
- I liked the suggestion of @PhilipVinc to rename to jaxtyping-extensions, and came close to doing this. In the end I decided against it, just because a shorter name is honestly a little more marketable. (Plus one fewer backward-incompatible change.)
- The general feedback was a preference for Float32 over f32, so that has now happened.
- There were lots of different ideas on what to call the integer types. In the end I decided to follow NumPy: Int for signed integers, UInt for unsigned integers, and Integer for both.
- With apologies to @ayaka14732, the goal of these changes was in large part to provide compatibility with static typing, which is important for compatibility with core JAX.
- @irhum - indeed Array[Float, "foo bar"] would have been a nice syntax, but unfortunately it is incompatible with static typing.
Once again, thank you everyone for your engagement!
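The integer naming settled on above (Int for signed, UInt for unsigned, Integer for both) can be illustrated with a small sketch of how a runtime checker might sort dtype names into the three categories. The dtype-name tables and the dtype_matches helper are hand-written assumptions for illustration, not read from NumPy or jaxtyping.

```python
# Hypothetical, hand-written dtype-name tables (not read from NumPy or jaxtyping).
SIGNED = {"int8", "int16", "int32", "int64"}
UNSIGNED = {"uint8", "uint16", "uint32", "uint64"}

CATEGORIES = {
    "Int": SIGNED,                 # signed integers only
    "UInt": UNSIGNED,              # unsigned integers only
    "Integer": SIGNED | UNSIGNED,  # either kind of integer
}


def dtype_matches(category: str, dtype_name: str) -> bool:
    """Return True if an array with this dtype name fits the given category."""
    return dtype_name in CATEGORIES[category]
```

So an int32 array fits Int and Integer but not UInt, and a uint8 array fits UInt and Integer but not Int.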
@patrick-kidger I am sorry if my question is out of scope, but I have a question about multi-argument type checking, since I see the concept mentioned in this thread: in the context of tensor shapes, what is special about multi-argument type checking (vs. single-argument type checking)? Would you mind briefly explaining how it works differently with multiple arguments? (I couldn't find related references with a few searches.) And why does the function have to be decorated with both @jaxtyped and @typechecker?
The multi-argument checking refers to checking that, across multiple arguments to a function, the sizes of multiple arrays agree. For example, def foo(x: Shaped[Array, "bar"], y: Shaped[Array, "bar"]) should be called with two arrays of the same size.
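To make that concrete, here is a minimal plain-Python sketch of the part that goes beyond single-argument checking: one dictionary of dimension sizes is shared by all arguments of a single call, so the first array binds "bar" to a size and every later array must match it. The shape_checked decorator, the FakeArray class, and the use of Annotated[object, "bar"] in place of Shaped[Array, "bar"] are hypothetical stand-ins, not jaxtyping's API. As far as I understand jaxtyping's design, creating that per-call scope is what @jaxtyped is for, while the separate typechecker decorator checks each argument against its annotation.

```python
import functools
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints


def shape_checked(fn):
    """Hypothetical decorator: enforce consistent named dimensions across arguments."""
    hints = get_type_hints(fn, include_extras=True)
    signature = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        sizes = {}  # dimension name -> size, shared by all arguments of this call
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # argument has no shape annotation
            dims = get_args(hint)[1].split()
            shape = tuple(value.shape)
            if len(shape) != len(dims):
                raise TypeError(f"{name}: expected {len(dims)} dims, got shape {shape}")
            for dim, size in zip(dims, shape):
                if sizes.setdefault(dim, size) != size:
                    raise TypeError(
                        f"{name}: dimension {dim!r} is {size}, but was {sizes[dim]} earlier"
                    )
        return fn(*args, **kwargs)

    return wrapper


class FakeArray:
    """Stand-in for a real array: only the .shape attribute matters here."""

    def __init__(self, *shape):
        self.shape = shape


@shape_checked
def foo(x: Annotated[object, "bar"], y: Annotated[object, "bar"]):
    return "ok"
```

Here foo(FakeArray(3), FakeArray(3)) succeeds, while foo(FakeArray(3), FakeArray(4)) raises TypeError: each argument on its own is a valid one-dimensional array, and only the shared sizes dictionary catches the disagreement.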