based-robotics / jaxadi

Transforms your CasADi functions into batchable JAX-compatible functions. By combining the power of CasADi with the flexibility of JAX, JAXADi enables the creation of efficient code that runs smoothly on CPUs, GPUs, and TPUs.
https://based-robotics.github.io/jaxadi/
MIT License

Dimensions bug #8

Closed simeon-ned closed 3 days ago

simeon-ned commented 1 week ago

Calling convert on a function whose arguments have different shapes results in an error: ValueError: All input arrays must have the same shape. The code that reproduces this bug:

import casadi as cs
from jaxadi import convert

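# Two symbolic inputs with different shapes: (3, 10) and (10, 2)
x = cs.SX.sym("x", 3, 10)
y = cs.SX.sym("y", 10, 2)
casadi_fn = cs.Function("myfunc", [x, y], [x @ y])
# Fails with ValueError: All input arrays must have the same shape
jax_fn = convert(casadi_fn, compile=True)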
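
For reference, here is a minimal sketch in plain JAX (no jaxadi involved) of what I would expect the converted function to return for these shapes, together with one way to get a ValueError like this. The error text looks like the one jnp.stack raises for inputs of different shapes. That the inputs get stacked somewhere inside convert is only my assumption; I have not confirmed it against the jaxadi source.

```python
import jax.numpy as jnp

# Concrete arrays with the same shapes as the CasADi symbols above.
x = jnp.ones((3, 10))
y = jnp.ones((10, 2))

# What the converted function is expected to compute: the matrix product x @ y.
expected = x @ y
print(expected.shape)  # (3, 2)

# Stacking arrays of different shapes raises a ValueError in JAX.
# Assumption: something similar is done with the inputs inside convert.
try:
    jnp.stack([x, y])
    stack_failed = False
except ValueError as err:
    stack_failed = True
    print("stack failed:", err)
```

If that assumption holds, handling each input separately instead of stacking them would avoid the shape restriction.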