sympytorch
A micro-library for turning SymPy expressions into PyTorch Modules.
All SymPy floats become trainable parameters. All SymPy symbols are inputs to the Module.
Installation
pip install git+https://github.com/patrick-kidger/sympytorch.git
Example
import sympy, torch, sympytorch
x = sympy.symbols('x_name')
cosx = 1.0 * sympy.cos(x)
sinx = 2.0 * sympy.sin(x)
mod = sympytorch.SymPyModule(expressions=[cosx, sinx])
x_ = torch.rand(3)
out = mod(x_name=x_) # out has shape (3, 2)
assert torch.equal(out[:, 0], x_.cos())
assert torch.equal(out[:, 1], 2 * x_.sin())
assert out.requires_grad # from the two Parameters initialised as 1.0 and 2.0
assert {p.item() for p in mod.parameters()} == {1.0, 2.0}
API
The API consists of a single object, SymPyModule.
It is initialised as SymPyModule(*, expressions), where expressions is a list of SymPy expressions.
It can be called, passing the values of the symbols as in the above example.
It has a method .sympy(), which returns the corresponding list of SymPy expressions. These may differ from the expressions it was initialised with if the values of its Parameters have changed, i.e. have been learnt.
Extensions
Not every PyTorch or SymPy operation is supported -- just the ones that I found I've needed! There's a dictionary here that lists the supported operations. Feel free to submit PRs for any extra operations you need.