ami-iit / adam

adam implements a collection of algorithms for calculating rigid-body dynamics in JAX, CasADi, PyTorch, and NumPy.
https://adam-docs.readthedocs.io/en/latest/
BSD 3-Clause "New" or "Revised" License

Allow pytorch batching using jax2torch #93

Closed: Giulero closed this pull request 2 months ago.

Giulero commented 2 months ago

This PR uses the jax2torch package, which is based on a gist explaining how to convert JAX functions into PyTorch functions while also preserving their gradients.
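The conversion idea can be sketched as a custom `torch.autograd.Function` whose forward pass evaluates the JAX function and whose backward pass applies the vector-Jacobian product returned by `jax.vjp`. This is a minimal sketch under stated assumptions, not the jax2torch implementation: the helper name `jax_to_torch` is hypothetical, it handles a single array output, it enables float64 in JAX so dtypes match `torch.float64`, and it copies data through NumPy on the CPU (a real bridge would typically avoid those host copies, e.g. via DLPack).

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch

# Assumption for this sketch: use float64 in JAX so dtypes match torch.float64.
jax.config.update("jax_enable_x64", True)


def _to_jax(t):
    # torch tensor -> NumPy -> JAX array (host copy; CPU only in this sketch)
    return jnp.asarray(t.detach().cpu().contiguous().numpy())


def jax_to_torch(fn):
    """Hypothetical helper: wrap a JAX function with one array output so it
    accepts torch tensors and propagates gradients through torch autograd."""

    class _JaxFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            # Evaluate fn in JAX and keep the VJP closure for the backward pass
            out, vjp_fn = jax.vjp(fn, *[_to_jax(a) for a in args])
            ctx.vjp_fn = vjp_fn
            return torch.from_numpy(np.array(out))

        @staticmethod
        def backward(ctx, grad_out):
            # Pull the incoming cotangent back through fn; one gradient per input
            grads = ctx.vjp_fn(_to_jax(grad_out))
            return tuple(torch.from_numpy(np.array(g)) for g in grads)

    return _JaxFunction.apply


# Usage: gradients computed by JAX reach the torch leaf tensors
f = jax_to_torch(lambda x, y: jnp.sin(x) * y)
x = torch.tensor([0.0, 1.0, 2.0], dtype=torch.float64, requires_grad=True)
y = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, requires_grad=True)
f(x, y).sum().backward()
print(x.grad)  # cos(x) * y
print(y.grad)  # sin(x)
```

Because the backward pass only needs the VJP closure stored on `ctx`, torch never has to trace the JAX computation itself; that is what lets gradients cross the framework boundary.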

This should allow batched computations, e.g.:

```python
import torch

# Replicate the base pose and the joint configuration n_samples times
H_b_batch = torch.tile(torch.tensor(H_b), (n_samples, 1, 1)).requires_grad_()
joints_val_batch = torch.tile(torch.tensor(joints_val), (n_samples, 1)).requires_grad_()

mass_matrix = kindyn.mass_matrix(H_b_batch, joints_val_batch)
# The mass matrix should have shape (n_samples, n_dofs + 6, n_dofs + 6)

# and the gradient
mass_matrix.sum().backward()
```
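To see why the shapes and the `sum().backward()` call above behave as described, the following PyTorch-only sketch replaces `kindyn.mass_matrix` with a hypothetical stand-in, `toy_batched_matrix`, which is NOT adam's mass-matrix algorithm; it only reproduces the (n_samples, n_dofs + 6, n_dofs + 6) output shape. Sizes and input values are made up for illustration.

```python
import torch

n_samples, n_dofs = 5, 3


def toy_batched_matrix(H_b, q):
    # Hypothetical stand-in, NOT adam's mass matrix: maps H_b (n, 4, 4) and
    # q (n, n_dofs) to one symmetric (n_dofs + 6) x (n_dofs + 6) matrix per sample.
    p = H_b[:, :3, 3]  # base position of each sample, (n, 3)
    v = torch.cat([p, p ** 2, torch.sin(q)], dim=1)  # (n, n_dofs + 6)
    eye = torch.eye(v.shape[1], dtype=v.dtype)
    return v.unsqueeze(2) * v.unsqueeze(1) + eye  # batched outer product + identity


# A single configuration, tiled as in the PR example
H_b = torch.eye(4, dtype=torch.float64)
H_b[:3, 3] = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)
joints_val = torch.tensor([0.5, -0.4, 1.2], dtype=torch.float64)

H_b_batch = torch.tile(H_b, (n_samples, 1, 1)).requires_grad_()
joints_val_batch = torch.tile(joints_val, (n_samples, 1)).requires_grad_()

M = toy_batched_matrix(H_b_batch, joints_val_batch)
print(M.shape)  # torch.Size([5, 9, 9])

# One backward pass on the summed batch fills the gradients of all samples
M.sum().backward()

# The same computation on sample 0 alone gives the same gradient
q_single = joints_val.unsqueeze(0).clone().requires_grad_()
toy_batched_matrix(H_b.unsqueeze(0), q_single).sum().backward()
print(torch.allclose(joints_val_batch.grad[0], q_single.grad[0]))  # True
```

Because each sample's output depends only on that sample's inputs, the gradient of the summed batch with respect to sample i equals the gradient of sample i's own output, so a single backward pass returns all per-sample gradients.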

I put this interface in adam.pytorch as KinDynComputationsBatch, even though it uses JAX under the hood.

```python
from adam.pytorch import KinDynComputationsBatch

# It is instantiated in the very same way as the other interfaces
kindyn = KinDynComputationsBatch(model_path, joints_name_list)
```
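A common way for a JAX backend to turn a single-sample function into a batched one is `jax.vmap`. The sketch below assumes that mechanism (the PR text does not spell it out) and uses a hypothetical per-sample function, `single_sample_fn`, which is not adam's code; a function batched this way could then be exposed to PyTorch through a bridge such as jax2torch.

```python
import jax
import jax.numpy as jnp


def single_sample_fn(H_b, q):
    # Hypothetical per-sample function, NOT adam's mass matrix:
    # H_b (4, 4) and q (n_dofs,) -> symmetric (n_dofs + 6, n_dofs + 6) matrix
    v = jnp.concatenate([H_b[:3, 3], H_b[:3, 3] ** 2, jnp.sin(q)])
    return jnp.outer(v, v) + jnp.eye(v.shape[0])


# vmap maps both arguments over their leading (batch) axis; jit compiles the result
batched_fn = jax.jit(jax.vmap(single_sample_fn, in_axes=(0, 0)))

n_samples, n_dofs = 4, 3
H_b_batch = jnp.tile(jnp.eye(4), (n_samples, 1, 1))  # (n_samples, 4, 4)
q_batch = jnp.linspace(0.0, 1.0, n_samples * n_dofs).reshape(n_samples, n_dofs)

out = batched_fn(H_b_batch, q_batch)
print(out.shape)  # (4, 9, 9)
```

With `in_axes=(0, 0)`, both arguments are mapped over their leading axis, so the single-sample code never has to handle the batch dimension explicitly.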

@traversaro has already opened https://github.com/conda-forge/staged-recipes/pull/26780#issuecomment-2194579078 (the conda-forge recipe for jax2torch), so everything will soon be conda-ready.

traversaro commented 2 months ago

Merged: https://github.com/conda-forge/jax2torch-feedstock

Giulero commented 2 months ago

All tests are passing now!

Giulero commented 2 months ago

Cc @evelyd @Zweisteine96