Transforms your CasADi functions into batchable JAX-compatible functions. By combining the power of CasADi with the flexibility of JAX, JAXADi enables the creation of efficient code that runs smoothly on CPUs, GPUs, and TPUs.
Calling convert on a CasADi function whose arguments have different shapes raises an error:
ValueError: All input arrays must have the same shape.
Code to reproduce the bug:
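import casadi as cs
from jaxadi import convert
# Two symbolic inputs with different shapes: x is 3x10, y is 10x2
x = cs.SX.sym("x", 3, 10)
y = cs.SX.sym("y", 10, 2)
# The output x @ y is a 3x2 matrix product
casadi_fn = cs.Function("myfunc", [x, y], [x @ y])
# Raises: ValueError: All input arrays must have the same shape.
jax_fn = convert(casadi_fn, compile=True)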
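For comparison, the sketch below uses plain JAX (not JAXADi) to show what a correct conversion of myfunc should compute. The function reference_fn is a hypothetical hand-written stand-in. The error text itself matches the one raised by jax.numpy.stack when its inputs differ in shape. That suggests, as an assumption rather than a confirmed diagnosis of JAXADi's internals, that the inputs are being stacked into a single array somewhere.

```python
import jax
import jax.numpy as jnp

# Concrete inputs with the same shapes as the CasADi symbols in the report.
x = jnp.ones((3, 10))
y = jnp.ones((10, 2))

# Hypothetical hand-written reference for what myfunc computes: x @ y.
def reference_fn(x, y):
    return x @ y

out = reference_fn(x, y)
print(out.shape)  # (3, 2)

# Batching over a leading axis with vmap works fine with mixed input shapes.
batched = jax.vmap(reference_fn)(jnp.ones((5, 3, 10)), jnp.ones((5, 10, 2)))
print(batched.shape)  # (5, 3, 2)

# Stacking differently shaped inputs into one array fails with a ValueError,
# which is the same kind of error reported above.
stack_error = None
try:
    jnp.stack([x, y])
except ValueError as err:
    stack_error = err
print(stack_error)
```

If that assumption holds, a fix would keep each argument as a separate array with its own shape instead of combining all inputs into one.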