
sympy2jax's Introduction

sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output will be an Equinox module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympy2jax

Requires:
  • Python 3.7+
  • JAX 0.3.4+
  • Equinox 0.5.3+
  • SymPy 1.7.1+

Example

import jax
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jax.numpy.zeros(3)
out = mod(x_sym=x)  # PyTree of results.
params = jax.tree_util.tree_leaves(mod)  # 1.0 and 2.0 are parameters.
                                         # (Which may be trained in the usual way for Equinox.)

Documentation

sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)

Where:

  • expressions is a PyTree of SymPy expressions.
  • extra_funcs is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
  • make_array controls whether integers/floats/rationals are stored as JAX arrays (make_array=True, the default) or as Python integers/floats/etc. (make_array=False).

Instances are called with keyword arguments that map each symbol's name to its value, as in the example above.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation, it's super easy.)

Finally

See also: other libraries in the JAX ecosystem

jaxtyping: type annotations for shape/dtype of arrays.

Equinox: neural networks.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Optimistix: root finding, minimisation, fixed points, and least squares.

Lineax: linear solvers.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Disclaimer

This is not an official Google product.

sympy2jax's People

Contributors

calbach, denehoffman, patrick-kidger


sympy2jax's Issues

sympy.Array support

Would it be possible to add support for sympy.Array objects? It might be useful in cases like:

from sympy import Array
from sympy.abc import x
from sympy2jax import SymbolicModule

a = Array([1,2,3])
e = a * x
j = SymbolicModule(e)
j(x=2)

# Output --> jax.DeviceArray([2, 4, 6], dtype=int64)

Obviously, this is for cases where there is no way to pass an array directly as "subs" to the SymbolicModule.

Thank you for the answer!

Physics-Informed Example

Hello,

This library looks super interesting! I'd like to get a better understanding of how to potentially use it for my research with physics-informed models. Would you please be able to provide a small example of building such a model with sympy2jax side-by-side with the equivalent model using only Equinox?

Thank you!

Crashes with imaginary numbers

I tried converting my complex sympy expression to jax, and got the following error.

I wrote a minimal working example. I is SymPy's imaginary unit and 1j is Python's; both fail in the same way.

from sympy import symbols, I
import sympy2jax

x = symbols("x")

expr = x*I # or x*1j

sympy2jax.SymbolicModule(expr)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    212 try:
--> 213     return memodict[expr]
    214 except KeyError:

KeyError: I*x

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    212 try:
--> 213     return memodict[expr]
    214 except KeyError:

KeyError: I

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:180, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    179 try:
--> 180     self._func = func_lookup[expr.func]
    181 except KeyError as e:

KeyError: <class 'sympy.core.numbers.ImaginaryUnit'>

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/Users/thomas/Documents/vilde.ipynb Cell 6 in <cell line: 8>()
      4 x = symbols("x")
      6 expr = x*I # or x*1j
----> 8 sympy2jax.SymbolicModule(expr)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:257, in SymbolicModule.__init__(self, expressions, extra_funcs, make_array, **kwargs)
    250     self.has_extra_funcs = True
    251 _convert = ft.partial(
    252     _sympy_to_node,
    253     memodict=dict(),
    254     func_lookup=lookup,
    255     make_array=make_array,
    256 )
--> 257 self.nodes = jax.tree_map(_convert, expressions)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in tree_map(f, tree, is_leaf, *rest)
    203 leaves, treedef = tree_flatten(tree, is_leaf)
    204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in <genexpr>(.0)
    203 leaves, treedef = tree_flatten(tree, is_leaf)
    204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    222     out = _Rational(expr, make_array)
    223 else:
--> 224     out = _Func(expr, memodict, func_lookup, make_array)
    225 memodict[expr] = out
    226 return out

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:183, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    181 except KeyError as e:
    182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
--> 183 self._args = [
    184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:184, in <listcomp>(.0)
    181 except KeyError as e:
    182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
    183 self._args = [
--> 184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    222     out = _Rational(expr, make_array)
    223 else:
--> 224     out = _Func(expr, memodict, func_lookup, make_array)
    225 memodict[expr] = out
    226 return out

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:182, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    180     self._func = func_lookup[expr.func]
    181 except KeyError as e:
--> 182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
    183 self._args = [
    184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

KeyError: "Unsupported Sympy type <class 'sympy.core.numbers.ImaginaryUnit'>"
