ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Document how the integrator deals with the system dynamics in jaxsim #304

Open CarlottaSartore opened 1 day ago

CarlottaSartore commented 1 day ago

While dealing with https://github.com/ami-iit/jaxcon/issues/1 I needed to extract the system acceleration and the contact forces from jaxsim.

I had a few issues understanding how the step method worked and extracting those quantities.

In this issue, I want to document what I found out and identify how to improve the readability of the code.

CarlottaSartore commented 1 day ago

With the help of @xela-95, I was able to work out how the integration is performed.

The system integration is performed as follows:

  1. We call the step method of the model: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/api/model.py#L2135
  2. The step method calls the integrator.step method: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/api/model.py#L2239-L2260
  3. integrator.step is a method of the abstract integrator class, and inside it the integrator object itself is called (via __call__): https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/integrators/common.py#L80-L114
  4. __call__ is an abstract method. Taking the Runge-Kutta implementation as an example, it calls the _compute_next_state function:
    https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/integrators/common.py#L228-L240
  5. There we then have a call to f, which represents the system dynamics. f is called inside a scan, which in turn calls a lax conditional (to condition the call on the initial condition): https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/integrators/common.py#L358
  6. f closes the system dynamics object over specific arguments: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/integrators/common.py#L319
  7. self.dynamics is a system dynamics object: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/integrators/common.py#L177
  8. System dynamics is a protocol: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/integrators/common.py#L38-L41
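
The call chain above can be sketched with a toy explicit Runge-Kutta integrator. This is not the jaxsim code: the class name `ExplicitRungeKutta`, the `decay` dynamics, and the assumption that the dynamics returns a `(derivative, aux)` tuple are all illustrative, and a plain Python loop stands in for the scan and lax conditional used in the real implementation.

```python
from typing import Any, Protocol

State = list[float]


class SystemDynamics(Protocol):
    """Anything callable as f(x, t, **kwargs) -> (x_dot, aux). Assumed interface."""

    def __call__(self, x: State, t: float, **kwargs: Any) -> tuple[State, dict[str, Any]]: ...


class ExplicitRungeKutta:
    """Toy explicit Runge-Kutta integrator defined by a Butcher tableau (A, b, c)."""

    def __init__(self, dynamics: SystemDynamics, A, b, c):
        self.dynamics = dynamics
        self.A, self.b, self.c = A, b, c

    def step(self, x0: State, t0: float, dt: float, **kwargs: Any) -> State:
        # Public entry point: delegates to __call__ (steps 2-3 of the list).
        return self(x0, t0, dt, **kwargs)

    def __call__(self, x0: State, t0: float, dt: float, **kwargs: Any) -> State:
        # Concrete integrators implement __call__ (step 4).
        return self._compute_next_state(x0, t0, dt, **kwargs)

    def _compute_next_state(self, x0: State, t0: float, dt: float, **kwargs: Any) -> State:
        # Close f over the optional kwargs (step 6); keep only the derivative.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)[0]

        k: list[State] = []  # one derivative per stage
        for i in range(len(self.b)):
            # Intermediate state: x0 + dt * sum_j A[i][j] * k[j]
            xi = [
                x0_n + dt * sum(self.A[i][j] * k[j][n] for j in range(i))
                for n, x0_n in enumerate(x0)
            ]
            k.append(f(xi, t0 + self.c[i] * dt))

        # Final state: x0 + dt * sum_i b[i] * k[i]
        return [
            x0_n + dt * sum(b_i * k_i[n] for b_i, k_i in zip(self.b, k))
            for n, x0_n in enumerate(x0)
        ]


def decay(x: State, t: float, rate: float = 1.0, **_: Any):
    # Toy dynamics dx/dt = -rate * x; a plain function satisfies the protocol.
    return [-rate * xi for xi in x], {}


# Heun's method expressed as a Butcher tableau.
heun = ExplicitRungeKutta(decay, A=[[0.0, 0.0], [1.0, 0.0]], b=[0.5, 0.5], c=[0.0, 1.0])
x1 = heun.step([1.0], t0=0.0, dt=0.1, rate=2.0)
```

Note that the extra keyword argument `rate` travels from `step` down to the dynamics only through `**kwargs`, which is the same route the closure over `f` relies on.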

Then, to understand where the specific system dynamics of the robot is defined, we have to go back to where the integrator is first defined:

  1. We define the system dynamics as an ODE state: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/api/model.py#L264-L269
  2. The ODE is wrapped as a system dynamics state: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/api/ode.py#L26-L80
  3. Then the system dynamics function computes the state to be integrated: https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/api/ode.py#L374-L452
  4. Then the ODE state is translated into the input to be integrated: https://github.com/ami-iit/jaxsim/blob/main/src/jaxsim/api/ode_data.py#L541

CarlottaSartore commented 1 day ago

After this analysis, it was still unclear how to extract the acceleration and forces information. Indeed, the system acceleration (which then needs to be integrated) is computed under the hood by the integrator, and it is evaluated for several different integrator states $y$.
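
To make this concrete, here is a small hand-written Heun (two-stage Runge-Kutta) step on a toy unit-mass spring. The `acceleration` function and all values are illustrative assumptions, not jaxsim code. The point is that the integrator evaluates the acceleration at two intermediate states, and neither value matches the acceleration at the state it finally returns.

```python
def acceleration(q: float, v: float) -> float:
    # Toy dynamics (assumption): unit-mass spring, a = -q.
    return -q


q0, v0, dt = 1.0, 0.5, 0.1

# Stage 1: derivative evaluated at the initial state.
a_stage1 = acceleration(q0, v0)

# Stage 2: derivative evaluated at a predicted (explicit Euler) state.
q_pred, v_pred = q0 + dt * v0, v0 + dt * a_stage1
a_stage2 = acceleration(q_pred, v_pred)

# Heun update: average the two stage derivatives.
q1 = q0 + dt * 0.5 * (v0 + v_pred)
v1 = v0 + dt * 0.5 * (a_stage1 + a_stage2)

# Acceleration consistent with the returned state (q1, v1):
# this needs one extra evaluation that the integrator never performs.
a_final = acceleration(q1, v1)
```

Here `a_stage1`, `a_stage2`, and `a_final` are three different numbers, so "the acceleration of the step" is ambiguous unless a convention is chosen.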

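I think there are two points now:

  1. How can we simplify the code so that it is easier to read?
  2. How can we export the joint acceleration and contact wrenches? I think we can assume that we want them computed at the final time of the integration, so that all the information (joint position, velocity, etc.) is associated with the same time step. One idea could be to call system_velocity_dynamic again with the specific input, but we would need to wrap it around other methods, and the result would still not be the simulator's own information.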

Do you have any ideas or suggestions? @ami-iit/darwin

flferretti commented 16 hours ago

Hi @CarlottaSartore, thank you for this analysis! Indeed, the integrator code is quite intricate and can easily lead to confusion. I'll try to address each of the points in the issue.

Regarding:

> After this analysis, it was still unclear how to extract the acceleration and forces information.

The acceleration $\dot{\nu} \in \mathbb{R}^{6 + N}$ can be computed, but not extracted, from either https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/api/model.py#L884-L891 or, if you want to take into account the contact dynamics, the joint limits, and the joint friction, from https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/api/ode.py#L88-L95

> How can we simplify the code so that it is easier to read?

The good thing about having protocols instead of abstract classes is that they help decouple the interface from the implementation: we don't actually need a class, just a single function that computes the system dynamics. Nevertheless, I'm not against using abstract classes.
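
The difference can be illustrated with a minimal, generic Python sketch (the names `Dynamics`, `AbstractDynamics`, `decay`, and `euler_step` are illustrative and do not come from jaxsim). With a protocol, a plain function already satisfies the interface structurally; with an abstract class, the dynamics has to be wrapped in a subclass.

```python
from abc import ABC, abstractmethod
from typing import Protocol


class Dynamics(Protocol):
    """Structural interface: any callable with this signature qualifies."""

    def __call__(self, x: float, t: float) -> float: ...


class AbstractDynamics(ABC):
    """Nominal interface: implementations must inherit from this class."""

    @abstractmethod
    def __call__(self, x: float, t: float) -> float: ...


def decay(x: float, t: float) -> float:
    # A plain function: matches the protocol without any inheritance.
    return -x


class Decay(AbstractDynamics):
    # The same dynamics, expressed through the abstract class.
    def __call__(self, x: float, t: float) -> float:
        return -x


def euler_step(f: Dynamics, x: float, t: float, dt: float) -> float:
    # The integrator only needs something callable as f(x, t).
    return x + dt * f(x, t)


x_from_function = euler_step(decay, 1.0, 0.0, 0.1)
x_from_class = euler_step(Decay(), 1.0, 0.0, 0.1)
```

Both calls produce the same result; the protocol version simply avoids the boilerplate class, at the cost of a less discoverable interface.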

> How can we export the joint acceleration and contact wrenches? I think we can assume that we want them computed at the final time of the integration, so that all the information (joint position, velocity, etc.) is associated with the same time step. One idea could be to call system_velocity_dynamic again with the specific input, but we would need to wrap it around other methods, and the result would still not be the simulator's own information.

Regarding this, I currently have two ideas:

  1. Modify js.model.step so that it returns additional info. This would give access to the system acceleration, but at the same time we would need to propagate some information that is currently internal to the code.
  2. Separate the dynamics and the integration into two additional methods. This would give access to intermediate information before manually integrating the dynamics. We could still keep the step function, while adding js.model.forward and js.model.integrate. The first function would compute the complete dynamics of the system, returning an ODEState object with the derivatives of the state variables, while the second one would simply call integrator.step in a completely functional way, by modifying https://github.com/ami-iit/jaxsim/blob/d2cf2206ef3962ca69d67c191ebeea86d6cc5d87/src/jaxsim/integrators/common.py#L80-L88 so that it accepts a system_dynamics keyword argument:
    def step(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep,
        *,
+       system_dynamics: SystemDynamics | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> tuple[NextState, dict[str, Any]]:
        # [...]

This way, system_dynamics is passed through kwargs to RungeKutta.__call__ and finally propagated to _compute_next_state:

    def _compute_next_state(
        self, x0: State, t0: Time, dt: TimeStep, **kwargs
    ) -> tuple[NextState, dict[str, Any]]:
        """
        Compute the next state of the system, returning all the output states.

        Args:
            x0: The initial state of the system.
            t0: The initial time of the system.
            dt: The time step of the integration.
            **kwargs: Additional keyword arguments.

        Returns:
            A batched state with as many batch elements as `b.T` rows.
        """

        # Call variables with better symbols.
        Δt = dt
        c = self.c
        b = self.b
        A = self.A

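+       # Use the dynamics passed through kwargs, if any, otherwise the stored one.
+       system_dynamics = kwargs.pop("system_dynamics", None)
+       dynamics = system_dynamics if system_dynamics is not None else self.dynamics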

        # Close f over optional kwargs.
-       f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
+       f = lambda x, t: dynamics(x=x, t=t, **kwargs)

        # [...]
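
As a rough illustration of the second idea, the sketch below separates a dynamics function from the integration on a toy free-falling point mass. Everything here is hypothetical: `forward` and `integrate` only mimic the proposed js.model.forward and js.model.integrate, `ForwardEuler` is a stand-in integrator, and the `(derivative, aux)` return convention is an assumption.

```python
from typing import Any, Callable, Optional

State = tuple[float, float]  # (position, velocity)
Dynamics = Callable[..., tuple[State, dict[str, Any]]]


def forward(x: State, t: float, gravity: float = -9.81) -> tuple[State, dict[str, Any]]:
    """Hypothetical counterpart of js.model.forward: state derivative plus aux data."""
    _, v = x
    a = gravity  # free-falling point mass
    return (v, a), {"acceleration": a}


class ForwardEuler:
    """Stand-in integrator that stores a default dynamics, like self.dynamics."""

    def __init__(self, dynamics: Dynamics):
        self.dynamics = dynamics

    def step(
        self,
        x0: State,
        t0: float,
        dt: float,
        *,
        system_dynamics: Optional[Dynamics] = None,
        **kwargs: Any,
    ) -> State:
        # Fall back to the dynamics stored at construction time.
        dynamics = system_dynamics if system_dynamics is not None else self.dynamics
        (q_dot, v_dot), _ = dynamics(x0, t0, **kwargs)
        return (x0[0] + dt * q_dot, x0[1] + dt * v_dot)


def integrate(integrator: ForwardEuler, x0: State, t0: float, dt: float, **kwargs: Any) -> State:
    """Hypothetical counterpart of js.model.integrate: a thin functional wrapper."""
    return integrator.step(x0, t0, dt, **kwargs)


integrator = ForwardEuler(dynamics=forward)

# Default dynamics, then one extra evaluation at the final state
# to get an acceleration associated with the same time step as x1.
x1 = integrate(integrator, (1.0, 0.0), t0=0.0, dt=0.01)
_, aux = forward(x1, t=0.01)

# Same integrator, different dynamics injected at call time.
moon = lambda x, t, **kw: forward(x, t, gravity=-1.62)
x1_moon = integrate(integrator, (1.0, 0.0), t0=0.0, dt=0.01, system_dynamics=moon)
```

With this split, the acceleration and any other auxiliary quantities come from an explicit `forward` call on the final state, rather than from values buried inside the integrator stages.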