patrick-kidger / sympytorch
Turning SymPy expressions into PyTorch modules.
License: Apache License 2.0
Consider the following:
import sympy, sympytorch, torch
x = sympy.symbols("x")
f = sympy.Function("f")(x)
f = x * sympy.I
sympytorch.SymPyModule(expressions=[f])
I get the following error. How can I make it work with complex functions?
In [16]: _calculate_symbolic_funcs()
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:97, in _Node.__init__(self, expr, _memodict, _func_lookup, **kwargs)
96 try:
---> 97 arg_ = _memodict[arg]
98 except KeyError:
KeyError: I
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
Cell In[16], line 1
----> 1 _calculate_symbolic_funcs()
Cell In[15], line 21, in _calculate_symbolic_funcs()
19 funcs = [i.subs({theta: sympy.acos(costheta)}) for i in funcs]
20 orig_funcs = [sympy.simplify(i).evalf() for i in funcs]
---> 21 results = sympytorch.SymPyModule(expressions=orig_funcs, extra_funcs={"conjugate": _tconjugate})
22 return results
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:152, in SymPyModule.__init__(self, expressions, extra_funcs, **kwargs)
148 _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
150 _memodict = {}
151 self._nodes = torch.nn.ModuleList(
--> 152 [_Node(expr=expr, _memodict=_memodict, _func_lookup=_func_lookup) for expr in expressions]
153 )
154 self._expressions_string = str(expressions)
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:152, in <listcomp>(.0)
148 _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
150 _memodict = {}
151 self._nodes = torch.nn.ModuleList(
--> 152 [_Node(expr=expr, _memodict=_memodict, _func_lookup=_func_lookup) for expr in expressions]
153 )
154 self._expressions_string = str(expressions)
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:99, in _Node.__init__(self, expr, _memodict, _func_lookup, **kwargs)
97 arg_ = _memodict[arg]
98 except KeyError:
---> 99 arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
100 _memodict[arg] = arg_
101 args.append(arg_)
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/site-packages/sympytorch/sympy_module.py:93, in _Node.__init__(self, expr, _memodict, _func_lookup, **kwargs)
91 self._args = ((lambda memodict: memodict[expr.name]),)
92 else:
---> 93 self._torch_func = _func_lookup[expr.func]
94 args = []
95 for arg in expr.args:
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/collections/__init__.py:986, in ChainMap.__getitem__(self, key)
984 except KeyError:
985 pass
--> 986 return self.__missing__(key)
File /opt/mambaforge/mambaforge/envs/m3gnet/lib/python3.10/collections/__init__.py:978, in ChainMap.__missing__(self, key)
977 def __missing__(self, key):
--> 978 raise KeyError(key)
KeyError: <class 'sympy.core.numbers.ImaginaryUnit'>
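The traceback points at the cause: the `else` branch of `_Node.__init__` looks up `_func_lookup[expr.func]`, and `sympy.I` is an instance of `sympy.core.numbers.ImaginaryUnit`, which has no entry. The sketch below reproduces that lookup-and-recurse pattern in a few lines to show that a zero-argument entry returning `torch.tensor(1j)` is enough to evaluate `x * I`. The `_lookup` table and `to_torch` function are hypothetical illustrations, not sympytorch's code.

```python
import functools

import sympy
import torch

# Hypothetical, stripped-down lookup table in the spirit of sympytorch's
# `_func_lookup` (NOT the library's actual contents).
_lookup = {
    sympy.Mul: lambda *args: functools.reduce(torch.mul, args),
    sympy.Add: lambda *args: functools.reduce(torch.add, args),
    # The missing entry: I has no args, so it maps to a zero-argument callable.
    sympy.core.numbers.ImaginaryUnit: lambda: torch.tensor(1j),
}


def to_torch(expr, inputs):
    """Recursively evaluate a SymPy expression with torch ops (illustrative only)."""
    if isinstance(expr, sympy.Symbol):
        return inputs[expr.name]
    if expr.is_Number:
        return torch.tensor(float(expr))
    # expr.func is the node's class (Mul, ImaginaryUnit, ...); recurse on its args.
    args = [to_torch(arg, inputs) for arg in expr.args]
    return _lookup[expr.func](*args)


x = sympy.symbols("x")
f = x * sympy.I
out = to_torch(f, {"x": torch.tensor([1.0, 2.0])})
print(out)  # complex tensor with imaginary parts 1 and 2
```

Inside sympytorch itself, the analogous workaround would presumably be `extra_funcs={sympy.core.numbers.ImaginaryUnit: lambda: torch.tensor(1j)}`, assuming `extra_funcs` is keyed by SymPy classes as the `_func_lookup[expr.func]` line suggests. By the same reasoning, the string key in `extra_funcs={"conjugate": _tconjugate}` would presumably never be matched.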
Then have the return value from SymPyModule.forward be a corresponding dictionary. That way we can have both named inputs and named outputs.
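A minimal sketch of the idea, assuming (as in the `forward` shown in a traceback on this page) that the module stacks its outputs along the last dimension: a hypothetical `NamedOutputs` wrapper splits that stacked tensor into a dictionary. `_Toy` is a stand-in for a SymPyModule built from `[x + y, x * y]`, so the example runs without sympytorch.

```python
import torch


class NamedOutputs(torch.nn.Module):
    """Hypothetical wrapper: splits a stacked output (last dim) into a dict keyed by name."""

    def __init__(self, module, names):
        super().__init__()
        self.module = module
        self.names = list(names)

    def forward(self, **inputs):
        out = self.module(**inputs)  # shape (..., len(self.names))
        return {name: out[..., i] for i, name in enumerate(self.names)}


class _Toy(torch.nn.Module):
    """Stand-in for a SymPyModule over the expressions [x + y, x * y]."""

    def forward(self, x, y):
        return torch.stack(torch.broadcast_tensors(x + y, x * y), dim=-1)


model = NamedOutputs(_Toy(), names=["sum", "prod"])
result = model(x=torch.tensor([1.0, 2.0]), y=torch.tensor([3.0, 4.0]))
print(result)  # {'sum': tensor([4., 6.]), 'prod': tensor([3., 8.])}
```

Wrapping rather than changing `forward` keeps the existing stacked-tensor behaviour available for code that already relies on it.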
I've been playing around and came across this today; it's possible this repo could take things one step further.
You can train chat models on your own data using Pinecone, and it would probably be possible to get results similar to the ones below with sympytorch without too much effort.
https://arxiv.org/pdf/2307.16171.pdf
ask question
Can you provide the LaTeX equations from the doc?
Pages 1-2: Equations, LaTeX, summarization
Here are some latex equations from the document:
Equation (1):
\[
\log p_{\theta}(x|c) \geq \mathbb{E}_{q_{\phi}(z|x)} \left[ \log p_{\theta_d}(x|z_a) - \log q_{\phi_a}(z_a|x_{\text{spec}}) + \log p_{\theta_a}(z_a|z_l) - \log q_{\phi_l}(z_l|x_{w2v}) \right]
\]
Equation (2):
\[
L_{\text{pitch}} = \|\log p_x - G_s(z_a, s)\|_1
\]
Equation (3):
\[
L_{\text{STFT}} = \|\psi(x) - \psi(G_w(z_a, p_h, s))\|_1
\]
Equation (4):
\[
L_{\text{adv}}(D) = \mathbb{E}_{x,z_a} \left[ (D(x)-1)^2 + D(G(z_a, s))^2 \right]
\]
Equation (5):
\[
L_{\text{adv}}(\phi_a, \theta_d) = \mathbb{E}_{z_a} \left[ (D(G(z_a, s))-1)^2 \right]
\]
ChatGPT
Can you provide some SymPy code for the following equations, with some dummy data? Equation (1):
Sure, I can provide SymPy code for each of the equations with some dummy data. Note that SymPy is mainly used for symbolic mathematics, so we'll be creating symbolic representations of the equations rather than solving them numerically.
Before we start, let's define the symbols used in the equations:
- x: Data variable.
- c: Some context.
- p_theta, p_theta_d, p_theta_a: Probability distributions with parameters theta, theta_d, and theta_a, respectively.
- q_phi, q_phi_a, q_phi_l: Probability distributions with parameters phi, phi_a, and phi_l, respectively.
- z: Latent variable.
- z_a: Latent variable related to the attribute a.
- z_l: Latent variable related to the label l.
- x_spec, x_w2v: Some specific data or context for x and w2v (word2vec) models, respectively.
- L_pitch, L_STFT, L_adv: Different loss terms.
- G_s, G_w: Functions that take some inputs and generate data.
Now let's define the equations in SymPy:
import sympy as sp
# Define symbols (all treated as scalars in this symbolic sketch)
x, c, z_a, z_l, x_spec, x_w2v, s, p_h, p_x = sp.symbols('x c z_a z_l x_spec x_w2v s p_h p_x')
# Conditional densities p(a|b) are written as undefined two-argument functions p(a, b)
p_theta, p_theta_d, p_theta_a = sp.symbols('p_theta p_theta_d p_theta_a', cls=sp.Function)
q_phi_a, q_phi_l = sp.symbols('q_phi_a q_phi_l', cls=sp.Function)
# Generators, discriminator and spectral feature map
G, G_s, G_w, D, psi = sp.symbols('G G_s G_w D psi', cls=sp.Function)
# Unevaluated expectation operator (sp.E is Euler's number, not an expectation)
Expectation = sp.Function('Expectation')
# Equation (1): lower bound on log p_theta(x|c)
eq1 = sp.log(p_theta(x, c)) >= Expectation(
    sp.log(p_theta_d(x, z_a)) - sp.log(q_phi_a(z_a, x_spec))
    + sp.log(p_theta_a(z_a, z_l)) - sp.log(q_phi_l(z_l, x_w2v)))
# Equation (2): Abs stands in for the L1 norm of a scalar
L_pitch = sp.Abs(sp.log(p_x) - G_s(z_a, s))
# Equation (3)
L_STFT = sp.Abs(psi(x) - psi(G_w(z_a, p_h, s)))
# Equation (4)
L_adv_D = Expectation((D(x) - 1)**2 + D(G(z_a, s))**2)
# Equation (5)
L_adv_phi_theta = Expectation((D(G(z_a, s)) - 1)**2)
# Print equations
print("Equation (1):")
sp.pprint(eq1)
print("\nEquation (2):")
sp.pprint(L_pitch)
print("\nEquation (3):")
sp.pprint(L_STFT)
print("\nEquation (4):")
sp.pprint(L_adv_D)
print("\nEquation (5):")
sp.pprint(L_adv_phi_theta)
Please note that these equations cannot be solved in the usual sense, because they contain expectations over probability distributions whose functional forms are unknown. The code above only creates symbolic representations of the equations for further manipulation or analysis in SymPy.
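The prompt also asked for dummy data. As a small sketch, Equation (2) can be made concrete by substituting a generator and numbers; the generator G_s(z_a, s) = z_a * s and the values p_x = e, z_a = 1/4, s = 2 are made up purely for illustration, and Abs again stands in for the L1 norm of a scalar.

```python
import sympy as sp

p_x, z_a, s = sp.symbols("p_x z_a s", positive=True)
G_s = sp.Function("G_s")

# Equation (2), scalar version
L_pitch = sp.Abs(sp.log(p_x) - G_s(z_a, s))

# Hypothetical dummy generator: G_s(z_a, s) = z_a * s (substituted first, before z_a and s change)
L_concrete = L_pitch.subs(G_s(z_a, s), z_a * s)

# Dummy values: log(e) = 1, so the loss is |1 - 1/4 * 2| = 1/2
value = L_concrete.subs({p_x: sp.E, z_a: sp.Rational(1, 4), s: 2})
print(value)  # 1/2
```

Once the undefined functions are replaced like this, what remains is an ordinary SymPy expression, which is the kind of input sympytorch.SymPyModule is meant to convert (assuming every operation in it, such as Abs, is in its lookup table).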
I did another one with physics: you can prompt GPT to dream up applications and dummy values. Here I gave it the equation directly and asked it to explain it in terms a child could understand.
https://gist.github.com/johndpope/ab1f334b615222b81505b6a26e5e3fbd
Hi, I've also coded a function equivalent to this one on my side.
Thank you very much for publishing your project as open source.
One of the problems with open source is fragmentation.
Our two projects seem very complementary, as they each have different strengths: yours is more torch-friendly, with elegant integration, while mine is more optimized on the points described in the title. I'd be happy to merge my project into yours; that way, I can delete mine. It's better to have one single library that does everything well than two that each do everything halfway!
Here's the source code for my project, which is perhaps a little more "sympy" than "torch" oriented:
https://framagit.org/robinechuca/cutcutcodec/-/blob/main/cutcutcodec/core/compilation/sympy_to_torch.py
Are you interested in implementing some of these optimizations in your project?
What a nice tool! However, the current code has trouble dealing with logical expressions of more than two terms, for example:
import sympy
import torch
from sympytorch import SymPyModule
x, y, z = sympy.symbols("x,y,z")
expr = x | y | z  # sympy.Or with three arguments
model = SymPyModule(expressions=[expr])
terms = {
    "x": torch.randint(0, 2, (10,)).bool(),
    "y": torch.randint(0, 2, (10,)).bool(),
    "z": torch.randint(0, 2, (10,)).bool(),
}
result = model(**terms)
Then it will raise an error:
TypeError: logical_or() takes 2 positional arguments but 3 were given
Here's the solution: modify sympytorch/sympytorch/sympy_module.py, lines 52 to 53 in ca3e3f4, to
sympy.And: _reduce(torch.logical_and),
sympy.Or: _reduce(torch.logical_or),
which is the same as for Mul and Add.
Also, besides the modification above, I suggest two more improvements:
1. sympytorch/sympytorch/sympy_module.py, lines 7 to 10 in ca3e3f4
2. sympytorch/sympytorch/sympy_module.py, line 158 in ca3e3f4:
_func_lookup = co.ChainMap(extra_funcs, _global_func_lookup)
so that users can overload the default lookup tables.
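The effect of swapping the ChainMap arguments can be checked with the standard library alone: `collections.ChainMap` searches its mappings left to right and returns the first hit, so whichever table comes first wins. The table contents below are hypothetical string stand-ins, not sympytorch's real entries.

```python
import collections as co

# Hypothetical stand-ins for the built-in table and a user's extra_funcs
_global_func_lookup = {"cos": "torch.cos (default)", "sin": "torch.sin (default)"}
extra_funcs = {"cos": "my_cos (user override)"}

current = co.ChainMap(_global_func_lookup, extra_funcs)   # global table searched first
proposed = co.ChainMap(extra_funcs, _global_func_lookup)  # user table searched first

print(current["cos"])   # torch.cos (default)
print(proposed["cos"])  # my_cos (user override)
print(proposed["sin"])  # torch.sin (default)
```

With the global table first, a user entry for an already-supported function is silently ignored; with `extra_funcs` first, it overrides the default while every other lookup still falls through to the global table.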
Hey! Thanks a lot for this package! :)
I just ran into what I think might be a minor bug:
In [1]: from sympytorch import SymPyModule
In [2]: import sympy
# this works
In [6]: m = SymPyModule(expressions=[sympy.core.numbers.Float(2.0)])
In [7]: m()
Out[7]: tensor([2.], grad_fn=<StackBackward0>)
# this does not
In [4]: m = SymPyModule(expressions=[sympy.core.numbers.Zero()])
In [5]: m()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 1
----> 1 m()
File ~/Library/Application Support/hatch/env/virtual/qucint/6NOL9orC/qucint/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Library/Application Support/hatch/env/virtual/qucint/6NOL9orC/qucint/lib/python3.9/site-packages/sympytorch/sympy_module.py:191, in SymPyModule.forward(self, **symbols)
189 def forward(self, **symbols):
190 out = [node(symbols) for node in self._nodes]
--> 191 out = torch.broadcast_tensors(*out)
192 return torch.stack(out, dim=-1)
File ~/Library/Application Support/hatch/env/virtual/qucint/6NOL9orC/qucint/lib/python3.9/site-packages/torch/functional.py:74, in broadcast_tensors(*tensors)
72 if has_torch_function(tensors):
73 return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 74 return _VF.broadcast_tensors(tensors)
TypeError: expected Tensor as element 0 in argument 0, but got int
If you have an idea of where to start looking, I could also take a stab at it.