diegoferigo closed this 3 months ago
Thanks @flferretti for the suggestion of a more compact syntax. The original idea was to avoid evaluating all the conditions if one of the first ones is false. Your suggestion using `all` would behave the same way only if its argument were a generator expression, which is not the case here. I updated your suggestion to use chained `and` instead, since Python short-circuits in this case too.
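To illustrate the difference, here is a small sketch that records which conditions actually run. The `check` helper is hypothetical, and the sketch assumes the original suggestion passed a list (or tuple) of already-written conditions to `all`. With a list, every condition is evaluated while the list is being built, before `all` even starts; a generator expression and chained `and` both stop at the first false value.

```python
calls = []


def check(name, result):
    """Hypothetical condition: record that it was evaluated, then return its result."""
    calls.append(name)
    return result


# 1. all() over a list: the list is built first, so every condition runs.
calls.clear()
all([check("a", False), check("b", True), check("c", True)])
list_calls = list(calls)

# 2. all() over a generator expression: conditions are produced lazily,
#    so all() stops at the first False.
calls.clear()
conditions = (("a", False), ("b", True), ("c", True))
all(check(name, result) for name, result in conditions)
gen_calls = list(calls)

# 3. Chained `and`: Python short-circuits at the first falsy operand.
calls.clear()
check("a", False) and check("b", True) and check("c", True)
and_calls = list(calls)

print(list_calls)  # ['a', 'b', 'c']
print(gen_calls)   # ['a']
print(and_calls)   # ['a']
```

Chained `and` keeps each condition as a plain expression, whereas the generator form would need the conditions to be expressible as a loop over data (or wrapped in lambdas).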
PR #173 changed how `JaxSimModel` objects are hashed and compared. That change was necessary because our frame-related data were read from `JaxSimModel.description`, which is a `Static` attribute. Before that PR, the hash of the description was explicitly ignored, but this caused problems when different models were passed to the same JIT-compiled function: JAX did not take into account possible differences in the frame data. As a result, the frames test was failing, and the fix was to make `JaxSimModel.description` properly hashable.

Recently I experimented with different `JaxSimModel`s (e.g. parameterized models), and I discovered that, given two identical instances `model1` and `model2` created from the same URDF (i.e., the second model is not a direct copy of the first one), calling on `model2` a function JIT-compiled on `model1` was extremely slow, even though no JIT recompilation was triggered. The problem is that in this case JAX computes the hash of the static attributes, and currently hashing a `ModelDescription` takes hundreds of milliseconds. This matters because processing the static attributes takes orders of magnitude longer than the compiled computation itself.

Originally, I looked for ways to speed up the equality check of `ModelDescription`, and this PR contains new `__eq__` methods that do not call `__hash__`. However, I then realized that the best solution is to also move the frame-related data into `KinDynParameters`, similarly to what we already do for links, joints, and contacts. This way, `JaxSimModel.description` can be ignored again, and its hash no longer needs to be computed. This significantly speeds up calling functions JIT-compiled on `model1` with `model2`.

There is still some overhead. My guess is that JAX stores the id of `model1` and skips the check on its static attributes, whereas that check is still performed for `model2`. The following is the runtime on CPU using two full ErgoCub models:

Note: this PR also goes in the direction of exploiting the on-disk JAX compilation cache supported by the GPU and TPU backends. In the past, I never managed to make it work with JaxSim, probably due to factors related to this PR. Things should be better now that hashing and equality are correct. What I suspect is still missing is wrapping static strings with a custom hash function, since by default the hash of a Python string differs between executions.
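A minimal sketch of the idea behind the fix, in plain Python. `HeavyDescription`, `FrameParameters`, and `Model` are hypothetical stand-ins for `ModelDescription`, the frame data moved into `KinDynParameters`, and `JaxSimModel`; they are not the actual JaxSim classes. The data that affect the computation live in a small, cheaply hashable container, while the heavy description is excluded from both `__eq__` and `__hash__`, so comparing and hashing two models never touches it. `jax.jit` requires static arguments to be hashable (both `__hash__` and `__eq__` implemented) and relies on them when looking up cached compilations, which is why keeping these methods cheap matters.

```python
import dataclasses
from typing import Tuple


class HeavyDescription:
    """Hypothetical stand-in for ModelDescription: assumed expensive to hash."""

    hash_calls = 0  # counts how often __hash__ is invoked

    def __init__(self, urdf_text: str) -> None:
        self.urdf_text = urdf_text

    def __hash__(self) -> int:
        HeavyDescription.hash_calls += 1
        return hash(self.urdf_text)  # imagine this taking hundreds of ms

    def __eq__(self, other: object) -> bool:
        # Compare contents directly instead of comparing hash values.
        return isinstance(other, HeavyDescription) and self.urdf_text == other.urdf_text


@dataclasses.dataclass(frozen=True)
class FrameParameters:
    """Hypothetical lightweight container for frame data (cf. KinDynParameters)."""

    names: Tuple[str, ...]
    parent_link_idx: Tuple[int, ...]


@dataclasses.dataclass(frozen=True)
class Model:
    """Hypothetical stand-in for JaxSimModel."""

    frames: FrameParameters
    # Excluded from the generated __eq__ and __hash__.
    description: HeavyDescription = dataclasses.field(compare=False, hash=False)


# Two models built independently from the same (hypothetical) URDF string.
urdf = "<robot name='demo'/>"
model1 = Model(FrameParameters(("base", "tool"), (0, 1)), HeavyDescription(urdf))
model2 = Model(FrameParameters(("base", "tool"), (0, 1)), HeavyDescription(urdf))

same = model1 == model2                    # compares only `frames`
same_hash = hash(model1) == hash(model2)   # hashes only `frames`
print(same, same_hash, HeavyDescription.hash_calls)  # True True 0
```

Excluding a field from equality is only safe when everything that influences the compiled computation is captured by the remaining fields, which is why the frame data have to move out of the description first.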
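Regarding the last point: Python randomizes `str` hashes per process (controlled by `PYTHONHASHSEED`), so `hash("base_link")` usually differs between runs. Below is a minimal sketch of a wrapper with a process-independent hash derived from a SHA-256 digest. `StableStr` is a hypothetical name, and whether such a wrapper alone is enough to make the on-disk compilation cache work with JaxSim is an assumption this sketch does not verify.

```python
import hashlib


class StableStr:
    """Hypothetical wrapper giving a string a process-independent hash."""

    __slots__ = ("value",)

    def __init__(self, value: str) -> None:
        self.value = value

    def __eq__(self, other: object) -> bool:
        return isinstance(other, StableStr) and self.value == other.value

    def __hash__(self) -> int:
        # Derive the hash from a SHA-256 digest of the UTF-8 bytes, so it
        # does not depend on the per-process str hash randomization.
        digest = hashlib.sha256(self.value.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], byteorder="little", signed=True)

    def __repr__(self) -> str:
        return f"StableStr({self.value!r})"


name = StableStr("base_link")
print(name.__hash__())  # same value in every interpreter run
```

A `StableStr` never compares equal to a plain `str`, which avoids mixing the two hashing schemes as keys of the same dict or set.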
📚 Documentation preview 📚: https://jaxsim--179.org.readthedocs.build//179/