
Comments (8)

MilesCranmer commented on July 28, 2024

Ran into this as well with PyTorch. For me, the solution (as described at https://docs.kidger.site/jaxtyping/api/array/#array) was to use torch.Tensor in place of jaxtyping.Array, like so:

from torch import Tensor
from jaxtyping import Float32

def f(x: Float32[Tensor, "dim1 dim2"]) -> Float32[Tensor, "dim1 dim2"]:
    return x


patrick-kidger commented on July 28, 2024

You need to have JAX installed as well.

jaxtyping has JAX only as an optional dependency, so that it can also be used with PyTorch etc.


danbider commented on July 28, 2024

@MilesCranmer thanks. @patrick-kidger, was the jax requirement relaxed? I don't see it in pyproject.toml anymore.
If so, I'll modify my code to use the syntax suggested by Miles.


patrick-kidger commented on July 28, 2024

Sorry, missed this question. Yes, jaxtyping no longer depends on JAX. The name is now for historical reasons only! The syntax Miles is using is correct.


pbsds commented on July 28, 2024

When authoring ML-runtime-agnostic tooling, such as a dataset, what is the correct array type to use? I cannot assume that torch, jax, or tensorflow is installed. I currently assume at least numpy and do the following, but it might not work for other use cases:

from typing import Union, TYPE_CHECKING
from jaxtyping import Float, Bool
if TYPE_CHECKING:
    from torch import Tensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    # TODO: tensorflow
    Array = Union[Tensor, ndarray, JaxArray]
else:
    from numpy import ndarray as Array
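The snippet above works because `typing.TYPE_CHECKING` is `True` only while a static checker (mypy, pyright) analyses the code and `False` at runtime: static tools see the full `Union`, while the running program only sees the `ndarray` fallback. Below is a stdlib-only sketch of the same mechanism; `Decimal`, `Fraction` and `float` are hypothetical stand-ins for the array types, and `Number`/`scale` are illustrative names.

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from decimal import Decimal
    from fractions import Fraction

    Number = Union[Decimal, Fraction, float]
else:
    # At runtime TYPE_CHECKING is False, so the alias is just the fallback type.
    Number = float


def scale(x: Number, k: float) -> Number:
    return x * k


print(scale.__annotations__["x"])  # <class 'float'>
```

One consequence: anything that inspects annotations at runtime (for example a runtime type checker) sees only the fallback, so with the original snippet a torch tensor would be checked against `ndarray` alone.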


patrick-kidger commented on July 28, 2024

Probably something like this:

from typing import Union, TYPE_CHECKING
if TYPE_CHECKING:
    from torch import Tensor as TorchTensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    from tensorflow import Tensor as TfTensor
    Array = Union[TorchTensor, ndarray, JaxArray, TfTensor]
else:
    arrays = []
    try:
        from torch import Tensor as TorchTensor
    except Exception:
        pass
    else:
        arrays.append(TorchTensor)
    try:
        from numpy import ndarray
    except Exception:
        pass
    else:
        arrays.append(ndarray)
    try:
        from jaxtyping import Array as JaxArray
    except Exception:
        pass
    else:
        arrays.append(JaxArray)
    try:
        from tensorflow import Tensor as TfTensor
    except Exception:
        pass
    else:
        arrays.append(TfTensor)
    Array = Union[tuple(arrays)]
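The four try/except stanzas in the runtime branch follow one pattern, so they can be driven by a list instead. Below is a minimal sketch of that refactoring using `importlib.import_module`; `union_of_available` is a hypothetical helper name, and the stdlib `decimal`/`fractions` modules stand in for real backends in the example.

```python
import importlib
from typing import Any, List, Sequence, Tuple, Union


def union_of_available(candidates: Sequence[Tuple[str, str]]) -> Any:
    """Return a Union of whichever `module.attr` types can be imported."""
    found: List[type] = []
    for module_name, attr in candidates:
        try:
            found.append(getattr(importlib.import_module(module_name), attr))
        except Exception:
            # Deliberately broad, as in the snippet above: a backend may be
            # missing, or installed but fail with some other error on import.
            continue
    if not found:
        # Union[()] would raise TypeError, so fail with a clearer message.
        raise ImportError("none of the candidate array types could be imported")
    # Union of a single type collapses to that type itself.
    return Union[tuple(found)]


# A missing module is skipped; the available types form the Union.
print(union_of_available([("decimal", "Decimal"), ("no_such_module_xyz", "Array"),
                          ("fractions", "Fraction")]))
```

For the backends discussed in this thread, the call would look like `union_of_available([("torch", "Tensor"), ("numpy", "ndarray"), ("jaxtyping", "Array"), ("tensorflow", "Tensor")])`, and adding a backend becomes a one-line change.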


pbsds commented on July 28, 2024

Neat! I'd go for except (ModuleNotFoundError, ImportError): 😉

And it doesn't exactly roll off the tongue. Any chance this could be added to jaxtyping?


patrick-kidger commented on July 28, 2024

Actually, the more general Exception is deliberate. There are cases where try-importing a module can fail with errors other than ImportError, cf. https://github.com/google/jaxtyping/blob/7a84b27da9e57c425ce4e6333121c3cdf2e07302/jaxtyping/_array_types.py#L33-L39
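A stdlib-only sketch of why `except ImportError` is not enough: it writes a throwaway module, `broken_backend` (a hypothetical stand-in for a backend that is installed but broken), which raises `RuntimeError` when imported. It also shows that `ModuleNotFoundError` is already a subclass of `ImportError`, so `except (ModuleNotFoundError, ImportError)` behaves exactly like `except ImportError`, and neither catches the `RuntimeError`. `try_import` is an illustrative helper name.

```python
import importlib
import sys
import tempfile
from pathlib import Path
from typing import Tuple, Type


def try_import(name: str, catch: Tuple[Type[BaseException], ...]) -> str:
    """Import `name`; report 'ok' or the caught error, letting others propagate."""
    try:
        importlib.import_module(name)
    except catch as e:
        return f"skipped ({type(e).__name__})"
    return "ok"


# A module that exists on sys.path but fails at import time with a non-ImportError.
tmp_dir = tempfile.mkdtemp()
Path(tmp_dir, "broken_backend.py").write_text('raise RuntimeError("bad install")\n')
sys.path.insert(0, tmp_dir)
importlib.invalidate_caches()

# A missing module raises ModuleNotFoundError, which ImportError already covers.
print(try_import("no_such_backend_xyz", (ImportError,)))  # skipped (ModuleNotFoundError)
# The broad `except Exception` also survives the broken module...
print(try_import("broken_backend", (Exception,)))  # skipped (RuntimeError)
# ...whereas `except ImportError` lets the RuntimeError escape.
try:
    try_import("broken_backend", (ImportError,))
except RuntimeError as e:
    print("escaped:", e)  # escaped: bad install
```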

As for adding the above to jaxtyping: jaxtyping tries to be essentially backend-agnostic. In particular, I don't think I'd want to hardcode that it looks for specifically torch+numpy+tensorflow+jax and nothing else. As such, I think something like this is out of scope for jaxtyping, I'm afraid.

