@traversaro already opened conda-forge/staged-recipes#26780, so everything will soon be conda-ready.
Merged: https://github.com/conda-forge/jax2torch-feedstock
All tests are passing now!
Cc @evelyd @Zweisteine96
This PR uses the `jax2torch` package, based on the gist that explains how to convert `jax` functions into `pytorch` ones while also preserving the gradients. This should make it possible to perform batch computations, e.g.
I put this interface in `adam.pytorch` as `KinDynComputationsBatch`, even though it uses `jax` under the hood.
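The idea of converting a `jax` function into a `pytorch` one while keeping gradients can be illustrated with a minimal, CPU-only sketch. It assumes both `jax` and `torch` are installed; it is not the `jax2torch` implementation, and the helper `jax_to_torch_fn` and the toy `squared_norm` function are hypothetical. The forward pass evaluates the JAX function through `jax.vjp`, and the backward pass feeds PyTorch's incoming gradient through the stored vector-Jacobian product. Wrapping a `jax.vmap`-ed function is one way to get batched evaluation.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def jax_to_torch_fn(jax_fn):
    """Hypothetical helper: wrap a single-argument JAX function so it accepts
    torch tensors and supports torch autograd (CPU-only sketch)."""

    class _JaxFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Copy the torch input into a JAX array via NumPy (no zero-copy sharing).
            x_jax = jnp.asarray(x.detach().cpu().numpy())
            # Evaluate the function and keep its vector-Jacobian product for backward.
            out, vjp_fn = jax.vjp(jax_fn, x_jax)
            ctx.vjp_fn = vjp_fn
            return torch.tensor(np.asarray(out))

        @staticmethod
        def backward(ctx, grad_out):
            # Pull the incoming torch gradient back through the JAX VJP.
            cotangent = jnp.asarray(grad_out.detach().cpu().contiguous().numpy())
            (grad_x,) = ctx.vjp_fn(cotangent)
            return torch.tensor(np.asarray(grad_x))

    return _JaxFunction.apply


# Hypothetical per-sample function, vectorized over a leading batch axis.
def squared_norm(q):
    return jnp.sum(q ** 2)


batched_squared_norm = jax_to_torch_fn(jax.vmap(squared_norm))

q = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
out = batched_squared_norm(q)  # one value per batch row: 5.0 and 25.0
out.sum().backward()
print(q.grad)  # gradient of sum(q**2) with respect to q, i.e. 2 * q
```

Round-tripping through NumPy keeps the sketch simple but copies data on every call; a production wrapper would typically share memory between the two frameworks and handle multiple inputs and GPU devices.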