Right, this is because jaxtyped returns a new object, different from the underlying type. (It was only designed to work on functions.)
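To see why returning a new object matters, here is a minimal sketch. It assumes the decorator hands back a plain function wrapper; `wrap_as_function` is a hypothetical stand-in, not jaxtyping's actual implementation. Instances can still be created, but the decorated name is no longer a class, so `isinstance` checks against it and subclassing from it both break. The sketch only checks that subclassing raises `TypeError`, because the exact message differs across Python versions.

```python
# Hypothetical decorator: like a function-only decorator applied to a class,
# it returns a new function object instead of the class itself.
def wrap_as_function(cls):
    def construct(*args, **kwargs):
        # Calling still works: it just forwards to the original class.
        return cls(*args, **kwargs)
    return construct


@wrap_as_function
class Base:
    def __init__(self, n: int):
        self.n = n


obj = Base(3)
print(type(obj).__name__, isinstance(Base, type))  # Base False

subclass_failed = False
try:
    # Base is now a function, so it cannot be used as a base class.
    class Child(Base):
        pass
except TypeError:
    subclass_failed = True
print(subclass_failed)  # True
```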
We could probably extend it to classes in the following way:
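import inspect

def jaxtyped(fn):
    if inspect.isclass(fn):
        # Wrap __init__ and patch it onto the class in place, so the
        # decorated class is still the original class object.
        init = jaxtyped(fn.__init__)
        fn.__init__ = init
        return fn
    else:
        ...  # existing implementation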
Give that a try and see if it works? If it does then I would be happy to accept the above as a PR.
from jaxtyping.
Yes, that works!
from jaxtyping import jaxtyped
#from beartype import beartype as typechecker
from typeguard import typechecked as typechecker
import inspect

def jaxtyped2(fn):
    if inspect.isclass(fn):
        # Classes: patch the wrapped __init__ onto the class and return the class itself.
        init = jaxtyped(fn.__init__)
        fn.__init__ = init
        return fn
    else:
        # Functions: fall back to the existing behaviour.
        return jaxtyped(fn)

@jaxtyped2
@typechecker
class LinearGaussianSSM():
    def __init__(
        self,
        state_dim: int,
        emission_dim: int,
        input_dim: int=0,
        has_dynamics_bias: bool=True,
        has_emissions_bias: bool=True
    ):
        self.state_dim = state_dim
        self.emission_dim = emission_dim
        self.input_dim = input_dim
        self.has_dynamics_bias = has_dynamics_bias
        self.has_emissions_bias = has_emissions_bias

@typechecker
class LinearGaussianConjugateSSM(LinearGaussianSSM):
    def __init__(self,
                 state_dim,
                 emission_dim,
                 input_dim=0,
                 has_dynamics_bias=True,
                 has_emissions_bias=True):
        super().__init__(state_dim=state_dim, emission_dim=emission_dim, input_dim=input_dim,
                         has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias)
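Stripped of jaxtyping and typeguard, the in-place approach can be sketched as follows. `count_calls` is a hypothetical stand-in for a function-only decorator, and `class_aware` mirrors the shape of `jaxtyped2` above; neither is a jaxtyping API. Because the class object itself is returned, `isinstance` checks and subclassing keep working.

```python
import functools
import inspect


def count_calls(fn):
    """Hypothetical function-only decorator: returns a new wrapper function."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1  # record every call that goes through the wrapper
        return fn(*args, **kwargs)
    wrapper.calls = 0
    return wrapper


def class_aware(fn):
    """Same shape as jaxtyped2: patch classes in place, wrap functions."""
    if inspect.isclass(fn):
        # Replace __init__ on the class itself and return the same class object.
        fn.__init__ = count_calls(fn.__init__)
        return fn
    return count_calls(fn)


@class_aware
class Base:
    def __init__(self, n: int):
        self.n = n


class Child(Base):  # works: Base is still a real class
    def __init__(self, n: int, m: int):
        super().__init__(n)  # goes through the wrapped Base.__init__
        self.m = m


c = Child(1, 2)
print(isinstance(c, Base), Base.__init__.calls)  # True 1
print(hasattr(Child.__init__, "calls"))  # False: Child.__init__ is unwrapped
```

One consequence worth noting: a subclass that defines its own __init__, like LinearGaussianConjugateSSM above, does not get that __init__ wrapped unless the subclass is decorated as well; only the super().__init__ call passes through the parent's wrapped __init__.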