It would be nice to unify the backends as was done for CasADi and NumPy. I had to move some methods into the JAX computation class and reimplement them because JAX types are immutable (there is no slice-assignment operator `[]`). The PyTorch implementation has a similar issue, since I had to cast some vectors to `torch.Tensor`. A possible solution is to subclass the JAX and PyTorch backends from a more abstract class in which some methods are redefined (for example `__setitem__` to handle JAX's immutable types, or a generic vector type that casts arrays and lists to `torch.Tensor`).
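A minimal sketch of the abstract-class idea, with hypothetical names (`ArrayLike`, `ImmutableBacked`, `MutableBacked` are not from the codebase). To keep the example self-contained it uses a plain tuple to stand in for an immutable JAX array and a list for a mutable tensor; the comments note what the real backend calls would be:

```python
import abc


class ArrayLike(abc.ABC):
    """Hypothetical wrapper giving every backend the same slice-assignment API."""

    def __init__(self, array):
        self.array = array

    @abc.abstractmethod
    def __setitem__(self, idx, value):
        ...


class ImmutableBacked(ArrayLike):
    """Stands in for a JAX-backed array: updates happen out of place.

    With a real jax.numpy array the body would be the functional update
    `self.array = self.array.at[idx].set(value)`.
    """

    def __setitem__(self, idx, value):
        data = list(self.array)
        data[idx] = value          # rebuild instead of mutating
        self.array = tuple(data)


class MutableBacked(ArrayLike):
    """Stands in for a torch-backed array: cast, then assign in place.

    With a real tensor the cast would be `torch.tensor(value)` before
    the in-place slice assignment.
    """

    def __setitem__(self, idx, value):
        self.array[idx] = value


v = ImmutableBacked((0.0, 0.0, 0.0))
v[0] = 1.0   # v.array is now (1.0, 0.0, 0.0), a new tuple
w = MutableBacked([0.0, 0.0, 0.0])
w[1] = 2.0   # w.array is mutated in place to [0.0, 2.0, 0.0]
```

Callers then always write `v[i] = x` and never need to know whether the backend mutates in place or returns a fresh array, which is the point of pushing `__setitem__` into the shared base class.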