patrick-kidger / sympy2jax

Turn SymPy expressions into trainable JAX expressions.
Apache License 2.0

sympy.Array support #11

Open lmriccardo opened 11 months ago

lmriccardo commented 11 months ago

Would it be possible to add support for sympy.Array objects? It might be useful in cases like:

from sympy import Array
from sympy.abc import x
from sympy2jax import SymbolicModule

a = Array([1,2,3])
e = a * x
j = SymbolicModule(e)
j(x=2)

# Output --> jax.DeviceArray([2, 4, 6], dtype=int64)

Obviously, this is meant for cases where it is not possible to pass an array directly as a substitution ("subs") to the SymbolicModule.

Thank you for the answer!

patrick-kidger commented 11 months ago

I'd be happy to take a PR (with tests) on this!