ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Speed up passing different `JaxSimModel` with same pytree structure to JIT-compiled functions #179

Closed diegoferigo closed 3 months ago

diegoferigo commented 3 months ago

PR #173 changed how JaxSimModel objects are hashed and compared. It was a necessary change since our frame-related data were read from JaxSimModel.description, which is a static attribute. Before that PR, the hash of the description was explicitly ignored, but that caused problems when different models were passed to the same JIT-compiled function, because JAX was not taking into account possible differences in frame data. The outcome was that the frames test was not passing, and the solution was to make JaxSimModel.description properly hashable.

These days I played around a bit with different JaxSimModels (e.g. parameterized models), and I discovered that if there are two identical instances model1 and model2 created from the same URDF (therefore, with the second model not being a direct copy of the first one), calling a JIT-compiled function on model2 was extremely slow even though JIT recompilation was not triggered. The problem is that in these cases JAX computes the hash of static attributes, and right now the hash of ModelDescription takes hundreds of milliseconds. This is a problem because the processing of static attributes is orders of magnitude longer than the actual compiled computation.

Originally, I thought of ways to speed up the equality computation of ModelDescription, and this PR contains new __eq__ methods that do not call __hash__.
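The idea behind those __eq__ methods can be sketched as follows. The class and field names here are illustrative, not JaxSim's actual attributes: compare fields directly with chained and, cheapest first, so that an expensive comparison often never runs, and keep a cheap __hash__ that is consistent with __eq__.

```python
from dataclasses import dataclass

# Hypothetical sketch: an __eq__ that compares fields directly with
# short-circuiting, instead of delegating to an expensive __hash__.
# Field names are illustrative, not JaxSim's actual attributes.

@dataclass
class Description:
    name: str
    links: tuple
    frames: tuple  # imagine comparing this is expensive

    def __eq__(self, other):
        if not isinstance(other, Description):
            return NotImplemented
        # 'and' stops at the first mismatch, so the expensive frame
        # comparison often never runs.
        return (
            self.name == other.name
            and self.links == other.links
            and self.frames == other.frames
        )

    def __hash__(self):
        # Cheap hash, consistent with __eq__: equal objects share a name.
        return hash(self.name)
```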

However, I then realized that the best solution is to also move the frame-related data into KinDynParameters, similarly to what we already do for links, joints, and contacts. In this way, JaxSimModel.description can be ignored again, and there's no longer any need to compute its hash. This significantly speeds up calling functions that have been JIT-compiled on model1 using model2.
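The mechanism can be sketched with the pytree flattening convention JAX uses. This is a simplified stand-in for JaxSim's actual registration code: children are traced and never hashed at dispatch time, while aux_data is static and becomes part of the JIT cache key, so moving the frame data into the traced parameters and returning empty aux_data removes the description from the key entirely.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical sketch of the fix (not JaxSim's actual code): in a JAX
# pytree, *children* are traced leaves, while *aux_data* is static and
# participates in the JIT cache key (hashed and compared on dispatch).

@dataclass
class Model:
    kin_dyn_parameters: dict  # numeric data: links, joints, contacts, frames
    description: Any          # static metadata, ignored again

def tree_flatten(model):
    children = (model.kin_dyn_parameters,)  # traced, never hashed
    aux_data = ()                           # nothing left to hash/compare
    return children, aux_data

def tree_unflatten(aux_data, children):
    return Model(kin_dyn_parameters=children[0], description=None)
```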

There is still an overhead. I guess that JAX saves the id of model1 and skips the check on static attributes, a check that is instead performed for model2. The following is the runtime on CPU using two full ErgoCub models:

import jaxsim.api as js
import resolve_robotics_uri_py
from jaxsim import VelRepr

urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
    uri="model://ergoCubSN001/model.urdf"
)

model1 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_path,
    is_urdf=True,
)

model2 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_path,
    is_urdf=True,
)

data = js.data.random_model_data(model=model1, velocity_representation=VelRepr.Mixed)

# First run for JIT compilation, second one for runtime.
%time _ = js.model.forward_dynamics_aba(model1, data)  # Wall time: 4.31 s
%timeit js.model.forward_dynamics_aba(model1, data)    # 296 µs ± 12.6 µs

# This should not trigger JIT recompilation.
%time _ = js.model.forward_dynamics_aba(model2, data)  # Wall time: 1.56 ms

# This should go almost as fast as on model1.
# The overhead is due to the comparison of static attributes.
%timeit js.model.forward_dynamics_aba(model2, data)  # 883 µs ± 14.2 µs

Note: this PR also goes in the direction of exploiting the on-disk JAX compilation cache supported by the GPU and TPU backends. In the past, I never managed to make it work with JaxSim, probably due to factors related to this PR. Now things should be better since hash and equality are correct. What I suspect is still missing is wrapping static strings with a custom hash function, since by default the hash of a string in Python is not the same across executions.
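The missing piece could look like the following sketch. HashedStr is a hypothetical wrapper, not an existing JaxSim class: Python's built-in hash() for str is salted per process (controlled by PYTHONHASHSEED), so it cannot key a cache that persists across executions, whereas a content-based digest is stable across runs.

```python
import hashlib

# Hypothetical sketch: a deterministic, content-based hash for strings.
# Python's built-in hash() for str is salted per process, so it cannot
# key an on-disk cache; a cryptographic digest is stable across runs.

def stable_hash(s: str) -> int:
    digest = hashlib.sha256(s.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=False)

class HashedStr(str):
    """A str whose hash is deterministic across Python processes."""

    def __hash__(self) -> int:
        return stable_hash(self)
```

Note that a HashedStr still compares equal to a plain str while hashing differently, so within a single process the two must not be mixed as keys of the same dict; this is acceptable for a sketch whose only goal is a run-independent cache key.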


📚 Documentation preview 📚: https://jaxsim--179.org.readthedocs.build//179/

diegoferigo commented 3 months ago

Thanks @flferretti for the suggestion on more compact syntax. The original idea was to avoid evaluating all the conditions if one of the first ones is false. Your suggestion using all would short-circuit similarly only if its argument were a generator expression, which was not the case. I updated your suggestion to use chained and, since Python also stops early in this case.
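The distinction can be checked directly. This minimal snippet records which conditions actually get evaluated in each variant:

```python
# Minimal check of the short-circuiting behavior discussed above:
# all() over a pre-built list evaluates every condition up front, while
# all() over a generator expression and chained 'and' both stop at the
# first falsy condition.

calls = []

def check(name, result):
    calls.append(name)
    return result

# all() over a list: both conditions run before all() even starts.
calls.clear()
all([check("a", False), check("b", True)])
assert calls == ["a", "b"]

# all() over a generator expression: stops after the first False.
calls.clear()
all(check(n, r) for n, r in [("a", False), ("b", True)])
assert calls == ["a"]

# Chained 'and': also stops at the first falsy operand.
calls.clear()
_ = check("a", False) and check("b", True)
assert calls == ["a"]
```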