Comments (13)
Nice! Beartype also has similar plans: beartype/beartype#235
I'd be happy to add support for either/both when they're added. In fact, maybe it's worth asking if they could standardise on an API.
from jaxtyping.
In principle, anything is possible with monkey patching :)
In practice that was a crazy solution that I'm not keen to repeat!
Absolutely! I completely agree.
So at the moment this is a limitation of the current approach. The checking is performed via isinstance, which simply returns True or False, and it's then up to either typeguard or beartype to take this and turn it into an error message. This means that there isn't really any way of returning additional information about why the isinstance check failed.
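To make the limitation concrete, here is a minimal stdlib-only sketch. It is not jaxtyping's actual code, and the names `_MetaFixedLen`, `Len3` and `toy_check` are hypothetical. A metaclass's `__instancecheck__` knows exactly why a value is rejected, but a checker built on `isinstance` (modelled loosely on the typeguard snippet quoted further down) only ever sees the boolean.

```python
class _MetaFixedLen(type):
    """Metaclass whose instance check accepts lists/tuples of a fixed length."""

    def __instancecheck__(cls, obj):
        # The reason for a failure (wrong type vs. wrong length) is known
        # here, but the isinstance protocol can only return a bool.
        if not isinstance(obj, (list, tuple)):
            return False
        return len(obj) == cls.length


class Len3(metaclass=_MetaFixedLen):
    length = 3


def toy_check(argname, value, expected_type):
    """Toy stand-in for a runtime type checker such as typeguard."""
    if not isinstance(value, expected_type):
        # Only the boolean result is available, so the message can say
        # *that* the check failed but not *why*.
        raise TypeError(
            f"type of {argname} must be {expected_type.__name__}; "
            f"got {type(value).__name__} instead"
        )


toy_check("x", [1, 2, 3], Len3)  # passes silently
try:
    toy_check("x", [1, 2], Len3)
except TypeError as e:
    print(e)  # type of x must be Len3; got list instead
```

The message says nothing about the length being wrong, even though `__instancecheck__` had that information in hand.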
I don't have a great solution in mind for this at the moment. I'd welcome any thoughts on how to accomplish this.
I see: you're doing all your work at https://github.com/google/jaxtyping/blob/35201eb189cc004276925f96e0aa6bfc469e46be/jaxtyping/array_types.py#L102, and then typeguard says:

```python
elif not isinstance(value, expected_type):
    raise TypeError(
        'type of {} must be {}; got {} instead'.
        format(argname, qualified_name(expected_type), qualified_name(value)))
```
Hmmm.
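So it turns out this isn't too noisy, since when your check fails we are almost certainly going to raise an error anyway:

```python
import jax.numpy as jnp

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        # Print why the check failed before returning False, because
        # typeguard/beartype only ever see the boolean result.
        if not isinstance(obj, jnp.ndarray):
            print(f'jaxtyping: {obj}:{type(obj)} is not a jnp.ndarray.')
            return False
        # `_any_dtype` is the jaxtyping-internal sentinel for "any dtype".
        if cls.dtypes is not _any_dtype and obj.dtype not in cls.dtypes:
            print(f'jaxtyping: {obj} dtype ({obj.dtype}) is not in {cls.dtypes}.')
            return False
        ...  # remaining shape checks, as in jaxtyping/array_types.py
```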
Yeah, adding our own manual print statements might be one approach. It's not super elegant, of course, so if we did this I'd probably add a global toggle controlling whether they are printed.
Exactly so. It might even be a case for, ugh, an environment variable, so a usage pattern might be:

```
% python t.py
...
Error message.
% JAXTYPING=verbose python t.py
```
Probably verbose should be the default? Probably more than 90% of exceptions for a library like this one will be thrown while the developer is watching, not in some production use case where the print statement would be an issue.
That said, it should probably still print to stderr, not stdout.
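Combining the last few suggestions (an environment variable, verbose by default, output on stderr) might look like the sketch below. The `JAXTYPING` variable comes from the usage pattern above. The opt-out value `quiet` and the helper names `failure_printing_enabled` and `report_failure` are assumptions made for illustration, not anything jaxtyping defines.

```python
import os
import sys


def failure_printing_enabled(environ=os.environ) -> bool:
    """Decide whether to print diagnostics for failed checks.

    Assumed convention: verbose unless JAXTYPING=quiet (case-insensitive).
    """
    return environ.get("JAXTYPING", "verbose").lower() != "quiet"


def report_failure(message: str, environ=os.environ) -> None:
    # Diagnostics go to stderr so they never mix with the program's stdout.
    if failure_printing_enabled(environ):
        print(f"jaxtyping: {message}", file=sys.stderr)
```

Passing `environ` explicitly is only there to make the behaviour easy to test. Real callers would rely on the default `os.environ`.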
Hi @patrick-kidger, any updates on this? It feels like this makes jaxtyping a bit frustrating to use with a typechecker, since shape mismatches are so common.
As it turns out, an analogous point has just been raised over on the beartype repo: beartype/beartype#216
If beartype includes a hook for this use case, then it's possible that we could add in some nicer error messages here.
Until then, my usual recommendation is to arrange for a debugger to open when things crash (e.g. pytest --pdb if using this at test time), and then just walk the stack trace looking at the object that was passed.
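When an interactive debugger is not convenient (in CI logs, say), the same "walk the stack and look at what was passed" idea can be done non-interactively with the standard library's `traceback.walk_tb`, which yields each `(frame, lineno)` pair in an exception's traceback. This is only a sketch and is not part of jaxtyping. `frame_locals`, `inner` and `outer` are hypothetical, and `inner` simply plays the role of a type-check failure.

```python
import traceback


def frame_locals(exc: BaseException):
    """Return (function name, local variables) for every frame in exc's traceback."""
    return [
        (frame.f_code.co_name, dict(frame.f_locals))
        for frame, _lineno in traceback.walk_tb(exc.__traceback__)
    ]


def inner(x):
    raise TypeError("type of x must be ...")  # stands in for a failed check


def outer():
    data = [1.0, 2.0]
    inner(data)


try:
    outer()
except TypeError as e:
    name, local_vars = frame_locals(e)[-1]  # innermost frame
    print(name, local_vars)  # inner {'x': [1.0, 2.0]}
```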
@patrick-kidger thanks for the swift response! Crazy how that timing worked out.
Just out of curiosity, do you think patching typeguard as in torchtyping could work as a temporary solution? I'm not requesting that it be added here, but figured I'd ask since it looks complicated.
@patrick-kidger it looks like typeguard 4 is adding support for a typecheck fail callback (see for example https://github.com/agronholm/typeguard/blob/master/src/typeguard/_functions.py#L116-L144). Maybe jaxtyping could make use of this when it's released?
Coming back to this, it looks like it might take quite a while for beartype and typeguard to standardize their APIs and implement them, so I think it would be nice to implement this now, even if guarded by a global flag. I am guessing that a better solution than the one with printing could be to decorate functions with another decorator that catches exceptions related to jaxtyping and re-raises them with better messages, while still preserving the original error message. Something like this:
```python
@jaxtyping.pretty_errors
@beartype.beartype
@jaxtyping.jaxtyped
def f():
    ...
```
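A decorator like the proposed `pretty_errors` could be sketched with plain `functools.wraps` and exception chaining. Here `pretty_errors`, `JaxtypingPrettyError` and `toy_checker` are all hypothetical, and `toy_checker` stands in for `@beartype.beartype` or typeguard. The sketch assumes the checker raises some subclass of `TypeError`, which holds for the typeguard snippet quoted earlier. It may not hold for every beartype version.

```python
import functools


class JaxtypingPrettyError(TypeError):
    """Hypothetical exception carrying the pretty-printed summary."""


def pretty_errors(func):
    """Hypothetical outermost decorator: re-raise checker errors with a summary.

    `raise ... from e` keeps the original exception as __cause__, so its
    message and traceback are still printed above the new one.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except TypeError as e:  # assumption: checker errors subclass TypeError
            summary = f"In the function {func.__name__!r}:\n  {e}"
            raise JaxtypingPrettyError(summary) from e
    return wrapper


def toy_checker(func):
    """Stand-in for a real runtime type checker decorator."""
    @functools.wraps(func)
    def wrapper(x):
        if not isinstance(x, int):
            raise TypeError(f"x={x!r} is not an int")
        return func(x)
    return wrapper


@pretty_errors
@toy_checker
def f(x):
    return x + 1
```

One design caveat: catching every `TypeError` also wraps errors raised inside the function body itself. A real version would need to recognise the checker's own exception types.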
I imagine re-raising could look similar to JAX's errors, so that a "pretty" error is printed after the original trace from the typechecker. Something like this:
```
BeartypeTypeHintViolation: blah blah blah / TypeError: blah blah blah

The above exception was the direct cause of the following exception:

In the function 'f' argument 'x':
    expected: Array["N", dtype=float]
    got: Array["1,2", dtype=float]
argument 'y':
    expected: Array["", dtype=int]
    got: Array["", dtype=float]
```
The catch is that we would have to branch on whether the error was raised by typeguard, by beartype, or by something else, transform the report of the offending arguments into a unified format, and only then pretty-print it.
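The two halves of that plan (working out which checker raised the error, then rendering a unified record) might be sketched as follows. `Mismatch`, `checker_of` and `format_mismatches` are hypothetical names. `checker_of` is a heuristic based only on the top-level package of the exception's class. Parsing each library's message into `Mismatch` records is omitted, because it depends on message formats that neither library guarantees.

```python
from dataclasses import dataclass


@dataclass
class Mismatch:
    """One argument that failed its check (hypothetical unified record)."""
    argname: str
    expected: str
    got: str


def checker_of(exc: BaseException) -> str:
    """Guess which library raised exc from the top-level package of its class."""
    package = type(exc).__module__.split(".")[0]
    return package if package in ("beartype", "typeguard") else "other"


def format_mismatches(func_name: str, mismatches: list) -> str:
    """Render mismatches in the 'expected/got' layout proposed in this thread."""
    lines = []
    for i, m in enumerate(mismatches):
        prefix = f"In the function {func_name!r} argument" if i == 0 else "argument"
        lines.append(f"{prefix} {m.argname!r}:")
        lines.append(f"    expected: {m.expected}")
        lines.append(f"    got: {m.got}")
    return "\n".join(lines)
```

Note that typeguard 2.x raises a plain builtin `TypeError` (as in the snippet quoted earlier), so this heuristic would classify it as "other". That is exactly the kind of special-casing that makes the approach fragile.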
Once the official API is implemented, we will need pretty-printing functionality anyway, so implementing it beforehand does not look like wasted work. And even though it is an ugly (and unstable) solution, I am guessing that most users of jaxtyping would greatly appreciate having this functionality at hand. In my case, for example, runtime type checking is mostly useful during prototyping and debugging, and this would save me quite a bit of time, since I would only need to glance at the trace instead of inserting jax.debug.print("{x}", x=x) at the place where 'f' is called.
FWIW, we ended up implementing a small wrapper that does this for typeguard. It is literally ~30 lines of code (of which only 2 lines are typeguard-specific, 15 lines do pretty printing, and the rest is boilerplate and comments), so instead of using @jt.jaxtyped we just use @util.jaxtyped.