
sympytorch's Introduction

sympytorch

Turn SymPy expressions into PyTorch Modules.

SymPy floats (optionally) become trainable parameters. SymPy symbols are inputs to the Module.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympytorch

Requires Python 3.7+, PyTorch 1.6.0+ and SymPy 1.7.1+.

Example

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
cosx = 1.0 * sympy.cos(x)
sinx = 2.0 * sympy.sin(x)
mod = sympytorch.SymPyModule(expressions=[cosx, sinx])

x_ = torch.rand(3)
out = mod(x_name=x_)  # out has shape (3, 2)

assert torch.equal(out[:, 0], x_.cos())
assert torch.equal(out[:, 1], 2 * x_.sin())
assert out.requires_grad  # from the two Parameters initialised as 1.0 and 2.0
assert {x.item() for x in mod.parameters()} == {1.0, 2.0}

API

sympytorch.SymPyModule(*, expressions, extra_funcs=None)

Where:

  • expressions is a list of SymPy expressions.
  • extra_funcs is a dictionary mapping from custom sympy.Functions to their PyTorch implementation. Defaults to no extra functions.

Instances of SymPyModule can be called, passing the values of the symbols as in the above example.

SymPyModule has a method .sympy(), which returns the corresponding list of SymPy expressions. (Which may not be the same as the expressions it was initialised with, if the values of its Parameters have been changed, i.e. have been learnt.)

Wrapping floats in sympy.UnevaluatedExpr will cause them not to be trained, by registering them as buffers rather than parameters.

sympytorch.hide_floats(expression)

As a convenience, hide_floats will take an expression and return a new expression with every float wrapped in a sympy.UnevaluatedExpr, so that it is interpreted as a buffer rather than a parameter.
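The wrapping that hide_floats performs can be approximated in plain SymPy. The function below, hide_floats_sketch, is a hypothetical re-implementation for illustration only. It is not sympytorch's actual code, which may differ in details such as how it treats floats that are already wrapped.

```python
import sympy


def hide_floats_sketch(expression):
    """Hypothetical stand-in for sympytorch.hide_floats: wrap every Float
    in sympy.UnevaluatedExpr and leave everything else untouched."""
    floats = expression.atoms(sympy.Float)
    # xreplace swaps each Float atom for its wrapped version in one pass
    return expression.xreplace({f: sympy.UnevaluatedExpr(f) for f in floats})


x = sympy.symbols("x")
expr = 2.0 * sympy.sin(x) + 0.5
hidden = hide_floats_sketch(expr)
print(hidden)
```

Calling hidden.doit() unwraps the floats again and recovers the original expression.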

Extensions

Not every PyTorch or SymPy operation is supported, just the ones I've found I needed! The supported operations are listed in the _global_func_lookup dictionary in sympytorch/sympy_module.py. Feel free to submit PRs for any extra operations you think should be supported by default. You can also use the extra_funcs argument to specify extra functions, including custom functions.

sympytorch's People

Contributors

casellimarco, fpollock, ipcamit, patrick-kidger, qthequartermasterman


sympytorch's Issues

Support sympy's `Zero`

Hey! Thanks a lot for this package! :)

I just ran into what I think might be a minor bug:

In [1]: from sympytorch import SymPyModule

In [2]: import sympy

# this works
In [6]: m = SymPyModule(expressions=[sympy.core.numbers.Float(2.0)])

In [7]: m()
Out[7]: tensor([2.], grad_fn=<StackBackward0>)

# this does not
In [4]: m = SymPyModule(expressions=[sympy.core.numbers.Zero()])

In [5]: m()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 m()

File ~/Library/Application Support/hatch/env/virtual/qucint/6NOL9orC/qucint/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Library/Application Support/hatch/env/virtual/qucint/6NOL9orC/qucint/lib/python3.9/site-packages/sympytorch/sympy_module.py:191, in SymPyModule.forward(self, **symbols)
    189 def forward(self, **symbols):
    190     out = [node(symbols) for node in self._nodes]
--> 191     out = torch.broadcast_tensors(*out)
    192     return torch.stack(out, dim=-1)

File ~/Library/Application Support/hatch/env/virtual/qucint/6NOL9orC/qucint/lib/python3.9/site-packages/torch/functional.py:74, in broadcast_tensors(*tensors)
     72 if has_torch_function(tensors):
     73     return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 74 return _VF.broadcast_tensors(tensors)

TypeError: expected Tensor as element 0 in argument 0, but got int

If you have an idea where to start looking I could also take a stab at it.
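One possible fix is sketched here as a standalone hypothetical helper, not as a patch to the library. The traceback shows forward() passing each node's output straight to torch.broadcast_tensors, which rejects the plain int that the Zero node returns. The helper first converts any non-tensor output with torch.as_tensor, so it can broadcast with the rest. Using the default float dtype for the conversion is an assumption.

```python
import torch


def stack_node_outputs(outputs):
    """Hypothetical helper mirroring SymPyModule.forward from the traceback,
    but converting plain Python numbers (such as the int produced for
    sympy's Zero) into tensors before broadcasting."""
    tensors = [
        out if isinstance(out, torch.Tensor)
        else torch.as_tensor(out, dtype=torch.get_default_dtype())
        for out in outputs
    ]
    tensors = torch.broadcast_tensors(*tensors)
    return torch.stack(tensors, dim=-1)


print(stack_node_outputs([torch.tensor([2.0]), 0]))  # tensor([[2., 0.]])
```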

not an issue - just some research in turning pdf -> code

I've been playing around and came across this today; it's possible this repo could take things one step further.
You can train chat models on your own data (using Pinecone), and it would probably be possible to get results similar to the ones below with sympytorch, without too much effort.

https://pdf2gpt.com/

https://arxiv.org/pdf/2307.16171.pdf

ask question
Can you provide the LaTeX equations from the doc?

Pages 1-2: Equations, LaTeX, summarization
Here are some latex equations from the document:


Equation (1):
\[
\log p_{\theta}(x|c) \geq \mathbb{E}_{q_{\phi}(z|x)} \left[ \log p_{\theta_d}(x|z_a) - \log q_{\phi_a}(z_a|x_{\text{spec}}) + \log p_{\theta_a}(z_a|z_l) - \log q_{\phi_l}(z_l|x_{w2v}) \right]
\]

Equation (2):
\[
L_{\text{pitch}} = \|\log p_x - G_s(z_a, s)\|_1
\]

Equation (3):
\[
L_{\text{STFT}} = \|\psi(x) - \psi(G_w(z_a, p_h, s))\|_1
\]

Equation (4):
\[
L_{\text{adv}}(D) = \mathbb{E}_{x,z_a} \left[ (D(x)-1)^2 + D(G(z_a, s))^2 \right]
\]

Equation (5):
\[
L_{\text{adv}}(\phi_a, \theta_d) = \mathbb{E}_{z_a} \left[ (D(G(z_a, s))-1)^2 \right]
\]

chatgpt
Can you provide some SymPy code for the following equations, with some dummy data? Equation (1):

Sure, I can provide SymPy code for each of the equations with some dummy data. Note that SymPy is mainly used for symbolic mathematics, so we'll be creating symbolic representations of the equations rather than solving them numerically.

Before we start, let's define the symbols used in the equations:

  • x: Data variable.
  • c: Some context.
  • p_theta, p_theta_d, p_theta_a: Probability distributions with parameters theta, theta_d, and theta_a, respectively.
  • q_phi, q_phi_a, q_phi_l: Probability distributions with parameters phi, phi_a, and phi_l, respectively.
  • z: Latent variable.
  • z_a: Latent variable related to the attribute a.
  • z_l: Latent variable related to the label l.
  • x_spec, x_w2v: Some specific data or context for x and w2v (word2vec) models, respectively.
  • L_pitch, L_STFT, L_adv: Different loss terms.
  • G_s, G_w: Functions that take some inputs and generate data.

Now let's define the equations in SymPy:

import sympy as sp

# Define symbols (scalars standing in for data, latents and conditioning variables)
x, c, z_a, z_l, x_spec, x_w2v = sp.symbols('x c z_a z_l x_spec x_w2v')
s, p_h, p_x = sp.symbols('s p_h p_x')

# Densities are undefined functions; a conditional p(a|b) is written p(a, b)
p_theta, p_theta_d, p_theta_a = sp.symbols('p_theta p_theta_d p_theta_a', cls=sp.Function)
q_phi_a, q_phi_l = sp.symbols('q_phi_a q_phi_l', cls=sp.Function)

# Generators G, G_s, G_w, discriminator D and the STFT feature map psi
G, G_s, G_w, D, psi = sp.symbols('G G_s G_w D psi', cls=sp.Function)

# The expectation (over q_phi(z|x) or the data) is kept as an unevaluated placeholder E[...]
E = sp.Function('E')

# Equation (1)
eq1 = sp.Ge(
    sp.log(p_theta(x, c)),
    E(sp.log(p_theta_d(x, z_a)) - sp.log(q_phi_a(z_a, x_spec))
      + sp.log(p_theta_a(z_a, z_l)) - sp.log(q_phi_l(z_l, x_w2v))),
)

# Equation (2): for scalars the L1 norm is the absolute value
L_pitch = sp.Abs(sp.log(p_x) - G_s(z_a, s))

# Equation (3)
L_STFT = sp.Abs(psi(x) - psi(G_w(z_a, p_h, s)))

# Equation (4)
L_adv_D = E((D(x) - 1)**2 + D(G(z_a, s))**2)

# Equation (5)
L_adv_phi_theta = E((D(G(z_a, s)) - 1)**2)

# Print equations
print("Equation (1):")
sp.pprint(eq1)

print("\nEquation (2):")
sp.pprint(L_pitch)

print("\nEquation (3):")
sp.pprint(L_STFT)

print("\nEquation (4):")
sp.pprint(L_adv_D)

print("\nEquation (5):")
sp.pprint(L_adv_phi_theta)

Please note that these equations will not be solvable in the usual sense, because the expectations are taken over probability distributions whose functional forms are unknown. The code above only creates symbolic representations of the equations for further manipulation or analysis in SymPy.

I did another one with physics. You can prompt GPT to dream up applications and dummy values; here I gave it the equation directly and asked it to explain it in terms a child would understand.
https://gist.github.com/johndpope/ab1f334b615222b81505b6a26e5e3fbd

Handling imaginary numbers

Consider the following:

import sympy, sympytorch, torch
x = sympy.symbols("x")
f = x * sympy.I  # complex-valued expression

sympytorch.SymPyModule(expressions=[f])

I get the following error. How can I make this work with complex-valued functions?

In [16]: _calculate_symbolic_funcs()
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:97, in _Node.__init__(self, expr, _memodict, _func_lookup, **kwargs)
     96 try:
---> 97     arg_ = _memodict[arg]
     98 except KeyError:

KeyError: I

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
Cell In[16], line 1
----> 1 _calculate_symbolic_funcs()

Cell In[15], line 21, in _calculate_symbolic_funcs()
     19 funcs = [i.subs({theta: sympy.acos(costheta)}) for i in funcs]
     20 orig_funcs = [sympy.simplify(i).evalf() for i in funcs]
---> 21 results = sympytorch.SymPyModule(expressions=orig_funcs, extra_funcs={"conjugate": _tconjugate})
     22 return results

File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:152, in SymPyModule.__init__(self, expressions, extra_funcs, **kwargs)
    148 _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
    150 _memodict = {}
    151 self._nodes = torch.nn.ModuleList(
--> 152     [_Node(expr=expr, _memodict=_memodict, _func_lookup=_func_lookup) for expr in expressions]
    153 )
    154 self._expressions_string = str(expressions)

File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:152, in <listcomp>(.0)
    148 _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
    150 _memodict = {}
    151 self._nodes = torch.nn.ModuleList(
--> 152     [_Node(expr=expr, _memodict=_memodict, _func_lookup=_func_lookup) for expr in expressions]
    153 )
    154 self._expressions_string = str(expressions)

File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:99, in _Node.__init__(self, expr, _memodict, _func_lookup, **kwargs)
     97     arg_ = _memodict[arg]
     98 except KeyError:
---> 99     arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
    100     _memodict[arg] = arg_
    101 args.append(arg_)

File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:93, in _Node.__init__(self, expr, _memodict, _func_lookup, **kwargs)
     91     self._args = ((lambda memodict: memodict[expr.name]),)
     92 else:
---> 93     self._torch_func = _func_lookup[expr.func]
     94     args = []
     95     for arg in expr.args:

File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/collections/__init__.py:986, in ChainMap.__getitem__(self, key)
    984     except KeyError:
    985         pass
--> 986 return self.__missing__(key)

File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/collections/__init__.py:978, in ChainMap.__missing__(self, key)
    977 def __missing__(self, key):
--> 978     raise KeyError(key)

KeyError: <class 'sympy.core.numbers.ImaginaryUnit'>
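The KeyError arises because sympy.I (the ImaginaryUnit class) has no entry in the lookup table. One workaround, assuming all inputs are real, is to split the expression into real and imaginary parts in SymPy before building the module. The expression below is an illustrative example.

```python
import sympy

# Declaring x real lets SymPy split an expression into real and imaginary parts
x = sympy.symbols("x", real=True)
f = x + 2 * x * sympy.I  # complex-valued expression of a real input

real_part = sympy.re(f)
imag_part = sympy.im(f)
print(real_part, imag_part)  # x 2*x
```

Each part is an ordinary real expression, so the two could be passed as expressions=[real_part, imag_part] and the output columns recombined with torch.complex. This does not add complex support to sympytorch itself. A part that simplifies to exactly zero (for example the real part of x * sympy.I) would run into the Zero issue above.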

Logic expression

What a nice tool! However, the current code has trouble with logic expressions of more than two terms, for example:

import sympy, torch
from sympytorch import SymPyModule

x, y, z = sympy.symbols("x,y,z")
expr = x | y | z
model = SymPyModule(expressions=[expr])
terms = {
        "x": torch.randint(0, 2, (10,)).bool(),
        "y": torch.randint(0, 2, (10,)).bool(),
        "z": torch.randint(0, 2, (10,)).bool(),
}
result = model(**terms)

Then it will raise an error:

TypeError: logical_or() takes 2 positional arguments but 3 were given

Here's the solution: modify

sympy.And: torch.logical_and,
sympy.Or: torch.logical_or,

to

sympy.And: _reduce(torch.logical_and),
sympy.Or: _reduce(torch.logical_or),

This is the same approach already used for Mul and Add.
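To see why the proposed fix works, the following standalone sketch folds a binary torch op over any number of arguments with functools.reduce. _reduce is defined exactly as in the suggestion below; the boolean tensors are made-up inputs.

```python
import functools as ft
import torch


def _reduce(fn):
    # Turn a binary torch op into a variadic one by folding it over all arguments
    def fn_(*args):
        return ft.reduce(fn, args)
    return fn_


logical_or = _reduce(torch.logical_or)
logical_and = _reduce(torch.logical_and)

x = torch.tensor([True, False, False, False])
y = torch.tensor([False, True, False, False])
z = torch.tensor([False, False, True, False])

print(logical_or(x, y, z).tolist())   # [True, True, True, False]
print(logical_and(x, y, z).tolist())  # [False, False, False, False]
```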

Also, besides the modification above, I suggest two more improvements:

  1. Make this function public:
    def _reduce(fn):
        def fn_(*args):
            return ft.reduce(fn, args)
        return fn_
  2. Modify
    _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)

    to

    _func_lookup = co.ChainMap(extra_funcs, _global_func_lookup)

    so that users can override the default lookup table.
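The argument order matters because collections.ChainMap searches its mappings from left to right and returns the first match. This sketch uses strings as stand-ins for the torch functions in the two lookup tables.

```python
import collections as co

# Stand-ins for the two lookup tables (strings instead of torch functions)
_global_func_lookup = {"sin": "default torch.sin"}
extra_funcs = {"sin": "user-supplied sin"}

current = co.ChainMap(_global_func_lookup, extra_funcs)
proposed = co.ChainMap(extra_funcs, _global_func_lookup)

print(current["sin"])   # default torch.sin
print(proposed["sin"])  # user-supplied sin
```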

Factorisation of subpatterns, float compatibility, in-place and memory optimisation...

Hi, I've also written an equivalent of this function on my side.
Thank you very much for publishing your project as open source.
One of the problems with open source is fragmentation.
Our two projects seem very complementary, as they have different strengths. Yours is more torch-friendly, with elegant integration; mine is more optimised on the points listed in the title. I'd be happy to merge my project into yours; that way, I can delete mine. It's better to have one library that does everything well than two that each do everything halfway!

Here's the source code for my project, which is perhaps a little more "sympy"-oriented than "torch"-oriented:
https://framagit.org/robinechuca/cutcutcodec/-/blob/main/cutcutcodec/core/compilation/sympy_to_torch.py

Would you be interested in implementing some of these optimisations in your project?
