Comments (10)

awf commented on July 28, 2024

Thanks both for this discussion - I implemented something quick based on @patrick-kidger's spike above, and it seems to work quite nicely. Next step is to integrate with jaxtyping, but I thought I would put it out here...

https://github.com/awf/awfutils#typecheck

A fairly direct copy of your suggestion above...

https://github.com/awf/awfutils/blob/7359acb6528325f6770fc9c28aab86f548d22ad4/typecheck.py#L133

from jaxtyping.

patrick-kidger commented on July 28, 2024

So checking intermediate annotations like this is really the job of a runtime type checker (typeguard/beartype) -- just like how checking the argument/return annotations is already the job of a runtime type checker. In either case all jaxtyping does is provide isinstance-compatible JAX types. After all, you may equally well wish to check x: int = foo(), which is an operation unrelated to JAX types.
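Plain Python does nothing with such annotations, which is why a checker has to step in. Per PEP 526, annotations on local variables inside a function body are neither evaluated nor stored. A small illustration (`foo`, `bar`, and `UndefinedHint` are made-up names):

```python
def foo():
    return "definitely not an int"


def bar():
    x: int = foo()                # never enforced: the annotation is ignored
    y: UndefinedHint = "ignored"  # never even evaluated, so no NameError
    return x, y


print(bar())  # runs without complaint despite both annotations being wrong
```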

What would probably be needed here is a decorator that parses the abstract syntax tree of a function, detects annotations, and then inserts manual isinstance checks such as assert isinstance(x, int). This probably qualifies as "not too bad" for someone already familiar with AST rewriting; probably something like

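import ast
import inspect
import textwrap

import beartype
import jaxtyping

def check_intermediate_annotations(fn):
  # Needs the function's source file, so this won't work in a REPL.
  tree = ast.parse(textwrap.dedent(inspect.getsource(fn)))
  # Drop the decorators so that re-executing the `def` doesn't re-apply them.
  tree.body[0].decorator_list = []
  # rewrite tree using ast.NodeVisitor and ast.NodeTransformer
  ast.fix_missing_locations(tree)
  # A `def` is a statement, so it must be exec'd rather than eval'd.
  namespace = {}
  exec(compile(tree, fn.__code__.co_filename, "exec"), fn.__globals__, namespace)
  return namespace[fn.__name__]

@jaxtyping.jaxtyped
@beartype.beartype
@check_intermediate_annotations
def bar(y):
  x: LxDk = foo()  # LxDk and foo are assumed to be defined elsewhere
  return x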

I'd note, though, that this kind of source-file parsing and unparsing is a bit fraught with edge cases (e.g. it won't work in a REPL). Unfortunately, Python just doesn't provide a good way to handle this kind of thing.

Off the top of my head I don't know of a runtime type checker that does this, though. (CC @leycec for interest.)

patrick-kidger commented on July 28, 2024

Right, so JAX sits in a lovely spot for applicability of runtime type checking, because Python is only ever being used as a metaprogramming language for XLA.

In this context I wouldn't worry about the extra runtime overhead.
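One way to see why the overhead hardly matters: under jax.jit, the Python body runs only while JAX traces the function (once per new input shape/dtype), and later calls reuse the compiled XLA computation. A small demonstration, assuming JAX is installed; the hand-written shape assertion stands in for a jaxtyping/beartype check, and `double` is an illustrative name.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def double(x):
    global trace_count
    # Python-level code (including any runtime type check) executes only while
    # JAX traces the function, not on every call of the compiled version.
    trace_count += 1
    assert x.shape == (3,), f"unexpected shape {x.shape}"
    return 2.0 * x


x = jnp.ones(3)
for _ in range(5):
    y = double(x)

print(trace_count)  # the body, and its check, ran once for five calls
```

Outside of jit (plain eager execution), the same check would run on every call.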

patrick-kidger commented on July 28, 2024

Not really documented yet, but for anyone coming across this issue: this now exists in beartype (beartype/beartype#7 (comment))!

leycec commented on July 28, 2024

Fascinating feature request intensifies. Thanks so much for pinging me into the fray, @patrick-kidger. Exactly as you suspect, no actively maintained runtime type-checker that I know of currently performs static type-checking at runtime. Sad cat emoji is sad. 😿

That said, @beartype does have an open feature request encouraging us to eventually do this. This is fun stuff, because it's hard stuff. Actually, it's not too hard to naively perform static type-checking at runtime by combining the fearsome powers of import hooks + abstract syntax tree (AST) inspection. But it's really hard to do so without destroying runtime performance, especially in pure Python. Extremely aggressive on-disk caching (e.g., the sordid pile of JSON files that mypy dumps into project-specific .mypy_cache/ subdirectories) would certainly be a hard prerequisite.

I Don't Like What I'm Hearing

Until then, @beartype provides a reasonably well-documented procedural API for type-checking arbitrary things against arbitrary type hints at any time:

query : LxDk = linear(head.query, t1)
key : LxDk = linear(head.key, t1)

# Runtime type-check everything above.
from beartype.abby import die_if_unbearable
die_if_unbearable(query, LxDk)
die_if_unbearable(key, LxDk)

If that's a bit too much egregious boilerplate, consider wrapping the above calls to linear() with a @beartype-friendly factory: e.g.,

from beartype.abby import die_if_unbearable

def linear_typed(*args, hint: object, **kwargs):
    '''
    Return linear(*args, **kwargs), runtime type-checked against the passed
    type hint. (hint is keyword-only: no parameter may follow **kwargs.)
    '''

    linear_array = linear(*args, **kwargs)
    die_if_unbearable(linear_array, hint)
    return linear_array

query : LxDk = linear_typed(head.query, t1, hint=LxDk)
key : LxDk = linear_typed(head.key, t1, hint=LxDk)

Technically, that still violates DRY a bit by duplicating the LxDk type hints. Pragmatically, that's probably as concise as we can manage... for the moment.

Let's pray I actually do something and make static runtime type-checking happen, everybody. 😮‍💨

awf commented on July 28, 2024

Thanks @leycec! Just to clarify on "destroying runtime performance", do you mean making it worse than

query : LxDk = linear(head.query, t1)
die_if_unbearable(query, LxDk)
key : LxDk = linear(head.key, t1)
die_if_unbearable(key, LxDk)

?

awf commented on July 28, 2024

Because the above seems acceptable to me, particularly under a jax-style define-by-run scheme.

And if it's not I might wrap die_if_unbearable to have logic like:

if rand() > 0.93: die_if_unbearable(query, LxDk)

or

if (beartype_time / total_program_time < bearable_beartype_overhead or
        rand() > beartype_time / total_program_time / bearable_beartype_overhead):
    die_if_unbearable(query, LxDk)

[I realise there are simplifications/corrections of the last logic, hope the sentiment is clear]
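A runnable version of the simpler, probabilistic idea, as a minimal sketch: `maybe_check` is a hypothetical helper, and `check` is any callable with the same `(value, hint)` calling convention as die_if_unbearable(). A probability of 0.07 corresponds to the `rand() > 0.93` rule.

```python
import random
from typing import Any, Callable, Optional


def maybe_check(
    check: Callable[[Any, Any], None],
    value: Any,
    hint: Any,
    probability: float = 0.07,
    rng: Optional[random.Random] = None,
) -> bool:
    """Run check(value, hint) with the given probability.

    Returns True if the check ran (and passed), False if it was skipped.
    """
    draw = rng.random() if rng is not None else random.random()
    if draw < probability:
        check(value, hint)  # expected to raise on a violation
        return True
    return False
```

Usage would look like `maybe_check(die_if_unbearable, query, LxDk)`. Passing a seeded `random.Random` makes which calls get checked reproducible.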

leycec commented on July 28, 2024

Agh! I should be more explicit in my jargon, especially when slinging around suspicious phrases like "destroying runtime performance." So...

Just to clarify on "destroying runtime performance", do you mean making it worse than

query : LxDk = linear(head.query, t1)
die_if_unbearable(query, LxDk)
key : LxDk = linear(head.key, t1)
die_if_unbearable(key, LxDk)

Yes. Much, much worse than that. die_if_unbearable() does not destroy runtime performance, because all @beartype operations to date (including die_if_unbearable() and @beartype-decorated wrapper functions alike) exhibit constant-time O(1) runtime performance with negligible constant factors. It don't get faster than that, right? Literally.
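To make the O(1) claim concrete: beartype's documentation describes checking containers by inspecting one randomly selected item per call rather than every item. The sketch below illustrates that idea for a list hint only; `check_list_of` is a hypothetical helper, not beartype's implementation.

```python
import random
from typing import Any, Optional


def check_list_of(value: Any, item_type: type,
                  rng: Optional[random.Random] = None) -> None:
    """Check `value` against list[item_type] in O(1) time.

    Only one randomly chosen item is inspected per call, so the cost does not
    grow with len(value); a bad item is therefore caught only probabilistically.
    """
    if not isinstance(value, list):
        raise TypeError(f"expected a list, got {type(value).__name__}")
    if value:  # an empty list trivially satisfies list[item_type]
        randrange = rng.randrange if rng is not None else random.randrange
        item = value[randrange(len(value))]
        if not isinstance(item, item_type):
            raise TypeError(
                f"list item {item!r} is not an instance of {item_type.__name__}"
            )
```

The trade-off is that a single bad item in a long list may slip through any one call, but repeated calls on the same data catch it with rising probability.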

By "destroy runtime performance," I was instead referring to the hypothetical crippling performance burden of doing static type-checking analysis at runtime via import hooks and AST inspection. @beartype don't do that yet; nobody does. The practical difficulties of optimizing static analysis at runtime are a Big Reason™ why.

But... someday @beartype or somebody else will go there. A runtime type-checker that efficiently performs static analysis at runtime would effectively obsolete standard static type-checkers (e.g., mypy, pyright) for most practical purposes. Since there is both Big Money™ and Big Hype™ for building that field of dreams, it will happen... someday.

Until then, we collectively wish upon a rainbow. 🌈

Because the above seems acceptable to me, particularly under a jax-style define-by-run scheme.

Absolutely. @beartype has been profiled to be disgustingly fast. That's the whole point, really. @beartype is actually two orders of magnitude faster than even pydantic, which is compiled down to C via Cython. Yeah. We're that fast.

You should never need to conditionally disable @beartype. If you do, bang on our issue tracker and we'll promptly resolve the performance regressions you are seeing. Until then, the best way to use @beartype is to just always use @beartype.

And if it's not I might wrap die_if_unbearable to have logic like:

if rand() > 0.93: die_if_unbearable(query, LxDk)

...heh. Probabilistic runtime type-checking. Love it! I must acknowledge cleverness when I see it. Admittedly, that also makes my eye twitch spasmodically.

If you do end up profiling die_if_unbearable() for your particular use case, please post your timings. That function is a bit less optimized than it could be, mostly because I didn't realize there was actual demand for statement-level runtime type-checking.

Now I know. And knowledge is half the battle.
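For anyone taking up the invitation to post timings, a minimal timing harness built on the standard library's timeit. `time_check` is a hypothetical helper; any callable with the same `(value, hint)` calling convention as die_if_unbearable() can be passed as `check`.

```python
import timeit
from typing import Any, Callable


def time_check(check: Callable[[Any, Any], None], value: Any, hint: Any,
               number: int = 100_000) -> float:
    """Return the mean wall-clock seconds per call of check(value, hint)."""
    if number <= 0:
        raise ValueError("number must be positive")
    total = timeit.timeit(lambda: check(value, hint), number=number)
    return total / number
```

With the names from the snippets above, that would be something like `time_check(die_if_unbearable, query, LxDk)`. The measurement includes a small constant overhead from the lambda wrapper.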

In this context I wouldn't worry about the extra runtime overhead.

These are sweet, soothing words. Please say more relieving things like this. 😌

leycec commented on July 28, 2024

Extremely impressive. @awfutils.typecheck is the first practical attempt I've seen at performing static type-checking at runtime. Take my thunderous clapping! 👏 👏 👏

Your current approach is outrageously useful, but appears to support only isinstance()-able classes rather than PEP-compliant type hints: e.g.,

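# I suspect this fails hard, but am lazy and thus did not test.
from typing import List

from awfutils import typecheck  # import path assumed from the awfutils README

@typecheck
def foo(x : List[int], y : int):
  z : List[int] = x * y
  w : float = z[0] * 3.2
  return w

foo([3, 2, 1], 1.3)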

Is that right? If so, that's still impressive tech for a several-hundred-line decorator. I'll open up a feature request on your issue tracker to see if we can't trivially generalize that to support all (...or at least most) PEP-compliant type hints, @awf.

In short, this is so good. \o/
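The gap is that isinstance() rejects subscripted generics outright: `isinstance([1], List[int])` raises TypeError. A minimal sketch of a checker that understands a few PEP 484 forms via typing.get_origin() and typing.get_args(), which could replace a bare isinstance() call in such a decorator. `satisfies` is hypothetical and covers only plain classes, Any, None, unions, and lists.

```python
import types
from typing import Any, List, Optional, Union, get_args, get_origin

# typing.Union covers Optional[X] and Union[X, Y]; types.UnionType covers the
# Python 3.10+ `X | Y` syntax (absent on older versions).
_UNION_ORIGINS = {Union, getattr(types, "UnionType", Union)}


def satisfies(value: Any, hint: Any) -> bool:
    """Check `value` against a small subset of PEP 484 hints (sketch only)."""
    if hint is Any:
        return True
    if hint is None:  # PEP 484 allows None as shorthand for type(None)
        hint = type(None)
    origin = get_origin(hint)
    if origin is None:  # a plain class such as int or float
        return isinstance(value, hint)
    if origin in _UNION_ORIGINS:
        return any(satisfies(value, arg) for arg in get_args(hint))
    if origin is list:
        (item_hint,) = get_args(hint) or (Any,)
        return isinstance(value, list) and all(
            satisfies(item, item_hint) for item in value
        )
    raise NotImplementedError(f"unsupported hint: {hint!r}")
```

Unlike beartype's O(1) sampling, this walks every list item. For full coverage of PEP-compliant hints, beartype's own `beartype.door.is_bearable(value, hint)` plays this role.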

leycec commented on July 28, 2024

Indeed. As @patrick-kidger notes, our new beartype.claw API transforms @beartype into a hybrid runtime-static type-checker. This is the way:

# In your top-level "{your_package}.__init__" submodule:
from beartype.claw import beartype_this_package
beartype_this_package()

That's it. @beartype will now type-check statement-level annotations in concert with jaxtyping.

Not really documented yet...

...yeah. Noticed that, huh? I've intentionally left beartype.claw undocumented for a bit. Technically, it's rock solid as is and "good enough" for most use cases and production workloads. Still, first impressions are everything; it'll really benefit from stability improvements in our upcoming @beartype 0.16.0 release, especially with respect to complex forward references (e.g., 'Dict[str, MuhGeneric[int]]') and PEP 563 (i.e., from __future__ import annotations). Thankfully, I've almost finalized @beartype 0.16.0 and expect it to land in a week or two.

Until then, one-liners for great QA justice! 💪 🐻
