ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Performing a simulation enforcing JAX to run in 32-bit precision overflows after 4.29 seconds #136

Open diegoferigo opened 5 months ago

diegoferigo commented 5 months ago

By default, JaxSim forces JAX to run in 64-bit precision. This is necessary to run our RBDAs and compare their results against third-party C++ libraries.

As explained in #109, running JaxSim in 32-bit precision might result in a significant speedup, especially on hardware accelerators like GPUs. Furthermore, TPUs only support 32-bit precision (https://github.com/google/jax/issues/9862), so JaxSim must be able to run at this precision if we want to exploit this hardware.
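The precision mode is controlled by JAX's jax_enable_x64 configuration flag. The following minimal sketch is not JaxSim code (how JaxSim sets the flag is not shown here); it only illustrates the downcast that a 64-bit integer counter undergoes: with the flag off, requesting jnp.uint64 yields a uint32 array and JAX emits a warning.

```python
import warnings

import jax
import jax.numpy as jnp

# By default JAX runs in 32-bit mode: requested 64-bit dtypes are downcast.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the "dtype not available" warning
    time_ns_32 = jnp.zeros((), dtype=jnp.uint64)
print(time_ns_32.dtype)  # uint32

# Enabling x64 (ideally at startup, e.g. with JAX_ENABLE_X64=1) keeps uint64.
jax.config.update("jax_enable_x64", True)
time_ns_64 = jnp.zeros((), dtype=jnp.uint64)
print(time_ns_64.dtype)  # uint64
```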

JaxSim is almost 32-bit compatible; the only caveats originate from how we handle the simulation time:

https://github.com/ami-iit/jaxsim/blob/10c719984c0b600afd165d65a25f215d8bc28a75/src/jaxsim/api/data.py#L42-L44

When running in 32-bit precision, time_ns is downcast to jnp.uint32. Unfortunately, the range of this data type is too narrow to represent nanoseconds. In fact:

In [1]: jnp.iinfo(jnp.uint32)
Out[1]: iinfo(min=0, max=4294967295, dtype=uint32)

The maximum time we can simulate without overflowing is therefore about 4.29 seconds (4294967295 ns). When running a longer simulation, the simulator hangs and has to be interrupted with Ctrl+C.
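To make the overflow concrete, here is a small NumPy sketch (NumPy's uint32 arrays wrap around on overflow, like jnp.uint32); the 1 ms integration step is a hypothetical value chosen for illustration. A counter just below the limit wraps to a small value instead of exceeding 4.29 s. A loop waiting for the counter to reach a later final time would then never terminate, which may explain the hang, although this mechanism is an assumption.

```python
import numpy as np

# Largest value of an unsigned 32-bit integer: 4294967295.
UINT32_MAX = np.iinfo(np.uint32).max

# Interpreted as nanoseconds, this is the longest representable time span.
max_seconds = UINT32_MAX / 1e9
print(f"{max_seconds:.3f} s")  # 4.295 s

# Hypothetical fixed integration step of 1 ms, in nanoseconds.
dt_ns = np.array([1_000_000], dtype=np.uint32)

# A time counter 0.5 ms below the limit: one more step wraps it around
# (NumPy array arithmetic wraps silently on integer overflow).
t_ns = np.array([UINT32_MAX - 500_000], dtype=np.uint32)
t_next = t_ns + dt_ns
print(t_next[0], bool(t_next[0] < t_ns[0]))  # 499999 True
```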

Something else to notice: a time in seconds represented as a 32-bit float has the following limits:

In [2]: jnp.finfo(jnp.float32)
Out[2]: finfo(resolution=1e-06, min=-3.4028235e+38, max=3.4028235e+38, dtype=float32)

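Therefore, in this case we cannot achieve a granularity down to 1 nanosecond. Note that resolution=1e-06 is only an approximate decimal resolution: float32 has a 24-bit significand, so the spacing between consecutive representable values grows with the magnitude of t, from about 119 ns around t = 1 s to about 7.6 µs around t = 100 s. That said, going below 1 ns is pointless for rigid-body simulations; already 10-100 ns is small enough.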
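The growth of the float32 spacing with the magnitude of t can be checked with np.spacing, which returns the distance to the next representable value. The sketch below also shows a consequence for a plain float32 time accumulator: once the spacing exceeds twice the step, t + dt rounds back to t and time stops advancing. The 0.1 ms step and the 1-hour mark are arbitrary illustrative values.

```python
import numpy as np

# Distance (in ns) between a float32 time in seconds and the next float32.
for t in (1.0, 100.0, 3600.0):
    gap_ns = float(np.spacing(np.float32(t))) * 1e9
    print(f"t = {t:6.0f} s -> float32 resolution ~ {gap_ns:,.0f} ns")

# Around t = 1 h the spacing is ~244 us, so a 0.1 ms step is lost to rounding.
t = np.float32(3600.0)
dt = np.float32(1e-4)
print(t + dt == t)  # True: the accumulated time no longer advances
```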


We should investigate alternative representations of the simulation time.

As a consequence, all functions accepting the current time should be ported to the new representation, which will be something more complex than a plain jax.Array:

diegoferigo commented 5 months ago

xref https://github.com/ami-iit/bipedal-locomotion-framework/pull/630. Note that JaxSim has used nanoseconds since the beginning (#1). The problem here is not handling time as nanoseconds, but storing nanoseconds in 32-bit types.
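A quick calculation supports this point: the limit comes from the width of the integer type, not from the choice of nanoseconds as the unit. The sketch below prints how long a nanosecond counter lasts for a few integer dtypes, using NumPy's iinfo (which reports the same limits as jnp.iinfo for these types).

```python
import numpy as np

SECONDS_PER_YEAR = 365.25 * 24 * 3600

# Maximum time span of a nanosecond counter for different integer widths.
for dtype in (np.int32, np.uint32, np.int64, np.uint64):
    max_s = np.iinfo(dtype).max / 1e9
    years = max_s / SECONDS_PER_YEAR
    print(f"{np.dtype(dtype).name:>6}: {max_s:.3g} s = {years:.3g} years")
```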

diegoferigo commented 5 months ago

The Python Standard Library has a decimal module that is quite interesting for this kind of problem. Unfortunately, there is nothing comparable in either NumPy or JAX. To be included in JaxSim and support jax.jit, the time representation has to be built from PyTree objects.
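For comparison, a minimal stdlib example of what decimal offers: summing ten 0.1 s steps is exact with Decimal, while binary floats accumulate rounding error. Decimal values are plain Python objects, however, so they are not valid JAX types and cannot be passed through jax.jit, which is why a PyTree-based equivalent would be needed.

```python
from decimal import Decimal

# Ten simulation steps of 0.1 s each.
float_time = sum(0.1 for _ in range(10))
decimal_time = sum(Decimal("0.1") for _ in range(10))

print(float_time)    # 0.9999999999999999
print(decimal_time)  # 1.0
```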

diegoferigo commented 5 months ago

A minimal working example (MWE) of a Time class should have at least the following features:

flferretti commented 3 months ago

I was thinking of using something that leverages functools.total_ordering for the comparison operators. We could also use jax_dataclasses.pytree_dataclass to maintain immutability as follows:

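from functools import total_ordering

import jax.numpy as jnp
import jax_dataclasses


@jax_dataclasses.pytree_dataclass
@total_ordering
class Time:
    # pytree_dataclass instances are frozen, so the field cannot be assigned
    # in a custom __init__; the generated __init__ is used instead.
    nanoseconds: jnp.int64

    @staticmethod
    def from_nanoseconds(nanoseconds) -> "Time":
        # Without jax_enable_x64, JAX downcasts int64 to int32.
        return Time(nanoseconds=jnp.int64(nanoseconds))

    def __add__(self, other):
        return Time(self.nanoseconds + other.nanoseconds)

    def __sub__(self, other):
        return Time(self.nanoseconds - other.nanoseconds)

    def __mod__(self, other):
        return Time(self.nanoseconds % other.nanoseconds)

    # total_ordering derives __le__, __gt__ and __ge__ from __eq__ and __lt__.
    # The derived methods use Python boolean logic, so they work on concrete
    # values but raise an error on traced values inside jax.jit.
    def __eq__(self, other):
        return self.nanoseconds == other.nanoseconds

    def __lt__(self, other):
        return self.nanoseconds < other.nanoseconds

    def advance(self, nanoseconds):
        return Time(self.nanoseconds + jnp.int64(nanoseconds))

    def __repr__(self):
        return f"Time({self.nanoseconds} ns)"
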
diegoferigo commented 3 months ago

This is a good idea; I'll have a look at functools.total_ordering for the comparison operators.

However, the main difficulty right now is handling the overflow properly when JAX runs in 32-bit precision. I have a kind-of-working local implementation, started about 2 months ago, in which seconds and nanoseconds are treated separately. Then the priorities shifted towards contact-related improvements; I'll have a look at this as soon as I find some spare time. I'm not sure we have use cases on TPUs in the near future.
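The seconds/nanoseconds split mentioned above can be sketched as follows. This is not the local implementation referred to in the comment; the class name SplitTime and its methods are hypothetical, and only non-negative times are assumed. Both fields are int32, so the type works in JAX's default 32-bit mode: nanoseconds stay in [0, 1e9) and overflow into the seconds field through an explicit carry, extending the range to about 68 years. Registering the class as a PyTree lets it pass through jax.jit and jax.lax loops.

```python
import jax
import jax.numpy as jnp

NS_PER_S = 1_000_000_000


@jax.tree_util.register_pytree_node_class
class SplitTime:
    """Hypothetical time type: int32 seconds plus int32 nanoseconds in [0, 1e9)."""

    def __init__(self, sec, nsec):
        # No conversion here: tree_unflatten may pass tracers or other leaves.
        self.sec = sec
        self.nsec = nsec

    @classmethod
    def from_ns(cls, ns: int) -> "SplitTime":
        # Split a Python integer number of nanoseconds on the host.
        sec, nsec = divmod(int(ns), NS_PER_S)
        return cls(jnp.int32(sec), jnp.int32(nsec))

    def __add__(self, other: "SplitTime") -> "SplitTime":
        # Both nsec fields are < 1e9, so their sum (< 2e9) fits in int32.
        total_ns = self.nsec + other.nsec
        carry = total_ns // NS_PER_S
        return SplitTime(self.sec + other.sec + carry, total_ns % NS_PER_S)

    def __lt__(self, other: "SplitTime") -> jax.Array:
        # Lexicographic comparison built from jnp ops, so it also works under jit.
        return (self.sec < other.sec) | (
            (self.sec == other.sec) & (self.nsec < other.nsec)
        )

    def to_float(self) -> jax.Array:
        # Lossy float32 conversion, e.g. for time-dependent terms in the dynamics.
        return self.sec.astype(jnp.float32) + self.nsec.astype(jnp.float32) * 1e-9

    def tree_flatten(self):
        return (self.sec, self.nsec), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Usage: advance 10 s in 1 ms steps, starting just below the uint32-ns limit.
dt = SplitTime.from_ns(1_000_000)
t0 = SplitTime.from_ns(4_294_000_000)
t1 = jax.lax.fori_loop(0, 10_000, lambda i, t: t + dt, t0)
print(int(t1.sec), int(t1.nsec))  # 14 294000000
print(bool(jax.jit(lambda a, b: a < b)(t0, t1)))  # True
```

The carry keeps both fields well inside the int32 range, which is the property a single 32-bit nanosecond counter lacks.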

diegoferigo commented 2 months ago

These days I realized that any option replacing a single float with a class to handle the simulation time would introduce major challenges in applications that need to compute gradients with respect to $t$ using AD. I don't yet have a solution that I'm happy with and that does not introduce side effects.
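A small illustration of the AD problem: jax.grad only accepts floating-point (or complex) inputs, so it cannot differentiate with respect to an integer time, such as a nanosecond counter or the integer fields of a split seconds/nanoseconds class. Differentiating with respect to a plain float t works. The function f(t) = sin(t) below is an arbitrary stand-in for a time-dependent quantity.

```python
import jax
import jax.numpy as jnp


def f(t):
    # Arbitrary time-dependent quantity, with t in seconds.
    return jnp.sin(t)


# Differentiating with respect to a floating-point time works: d/dt sin(t) = cos(t).
grad_float = jax.grad(f)(jnp.float32(1.0))
print(grad_float)  # ~0.5403

# With time stored as integer nanoseconds, jax.grad rejects the input.
int_grad_rejected = False
try:
    jax.grad(lambda t_ns: f(t_ns * 1e-9))(jnp.int32(1_000_000_000))
except TypeError as err:
    int_grad_rejected = True
    print("TypeError:", err)
```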