ami-iit / adam

adam implements a collection of algorithms for calculating rigid-body dynamics in Jax, CasADi, PyTorch, and Numpy.
https://adam-docs.readthedocs.io/en/latest/
BSD 3-Clause "New" or "Revised" License

Unify Jax and PyTorch backends #5

Closed: Giulero closed this issue 2 years ago

Giulero commented 2 years ago

It would be nice to unify the backends, as was done for CasADi and NumPy. I had to move some methods into the Jax computation class and reimplement them because Jax arrays are immutable (there is no slice assignment with the [] operator). The PyTorch implementation has a similar issue, since I had to cast some vectors to torch.tensor. One option could be to subclass the Jax and PyTorch backends from a more abstract class in which some methods are redefined (for example, __setitem__ to handle Jax's immutable types, or a general vector type that casts arrays and lists to torch.tensor).
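A minimal sketch of the suggested abstraction, assuming a wrapper class that each backend subclasses. The names `ArrayLike`, `MutableArray`, `FunctionalArray`, `cast` and `fill_head` are hypothetical and are not adam's API. To keep the sketch self-contained it uses NumPy only: `FunctionalArray` emulates Jax's immutable arrays with a read-only NumPy array and copy-on-write updates. In an actual Jax backend the update would presumably be `x.at[idx].set(value)`, and a PyTorch backend's `cast` could call `torch.as_tensor`.

```python
from abc import ABC, abstractmethod
from typing import Any

import numpy as np


class ArrayLike(ABC):
    """Hypothetical backend-agnostic wrapper: algorithm code writes
    `v[idx] = value` and each backend decides how that is realized."""

    def __init__(self, value: Any) -> None:
        # Cast lists, tuples and arrays to the backend's native type
        # (a PyTorch backend would cast to torch.tensor here).
        self.array = self.cast(value)

    @staticmethod
    @abstractmethod
    def cast(value: Any) -> Any: ...

    @abstractmethod
    def __setitem__(self, idx: Any, value: Any) -> None: ...

    def __getitem__(self, idx: Any) -> Any:
        return self.array[idx]


class MutableArray(ArrayLike):
    """NumPy-style backend: slice assignment mutates the buffer in place."""

    @staticmethod
    def cast(value: Any) -> np.ndarray:
        return np.asarray(value, dtype=float)

    def __setitem__(self, idx: Any, value: Any) -> None:
        self.array[idx] = value


class FunctionalArray(ArrayLike):
    """Jax-style backend (emulated): the array is read-only, so an update
    builds a new array and rebinds it. With Jax this body would be
    `self.array = self.array.at[idx].set(value)`."""

    @staticmethod
    def cast(value: Any) -> np.ndarray:
        arr = np.array(value, dtype=float)
        arr.flags.writeable = False  # emulate Jax immutability
        return arr

    def __setitem__(self, idx: Any, value: Any) -> None:
        updated = self.array.copy()  # the copy is writeable
        updated[idx] = value
        updated.flags.writeable = False
        self.array = updated


def fill_head(v: ArrayLike) -> ArrayLike:
    # Backend-agnostic algorithm code: identical syntax for every backend.
    v[0:2] = [1.0, 2.0]
    return v


m = fill_head(MutableArray([0, 0, 0]))
f = fill_head(FunctionalArray([0, 0, 0]))
print(m.array, f.array)  # [1. 2. 0.] [1. 2. 0.]
```

Since `fill_head` only uses `v[...] = ...`, the backend-specific behaviour (in-place mutation versus functional update, and casting inputs to the native array type) stays in one place per backend instead of being reimplemented in each computation class.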