ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Give the user the possibility of choosing what to JIT compile #301

Open flferretti opened 3 days ago

flferretti commented 3 days ago

Currently, we decorate every method and function inside `jaxsim.api` with `jax.jit`. However, this introduces overhead, because nested jitted functions get traced multiple times (XLA compilation itself still runs only once, for the outermost `jit`):

- Single JIT:

```python
>>> import jax
>>>
>>> def fn(x: int, y: int):
...     return x * y
...
>>> with jax.log_compiles():
...     jax.jit(fn)(4, 3)
...
Finished tracing + transforming multiply for pjit in 0.0005085468292236328 sec
Finished tracing + transforming fn for pjit in 0.0011935234069824219 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.0018908977508544922 sec
Finished XLA compilation of jit(fn) in 0.03702688217163086 sec
```

- Multiple JIT:

```python
>>> import jax
>>>
>>> def fn(x: int, y: int):
...     return x * y
...
>>> with jax.log_compiles():
...     jax.jit(jax.jit(jax.jit(jax.jit(jax.jit(fn)))))(4, 3)
...
Finished tracing + transforming multiply for pjit in 0.0003635883331298828 sec
Finished tracing + transforming fn for pjit in 0.0006585121154785156 sec
Finished tracing + transforming fn for pjit in 0.0009508132934570312 sec
Finished tracing + transforming fn for pjit in 0.0011334419250488281 sec
Finished tracing + transforming fn for pjit in 0.0013675689697265625 sec
Finished tracing + transforming fn for pjit in 0.0017311573028564453 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.0028264522552490234 sec
Finished XLA compilation of jit(fn) in 0.040222883224487305 sec
```

While this is convenient when using JaxSim as a multibody dynamics library, it can lead to unexpected results or additional overhead that could be avoided.
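A minimal sketch of how the choice could be handed to the user, assuming a hypothetical decorator `maybe_jit` and a hypothetical environment variable `JAXSIM_HYPOTHETICAL_JIT` (neither exists in JaxSim or JAX). API functions stay jitted by default, but a user who prefers a single `jax.jit` at the outermost level can switch the inner decorators off. Keyword arguments are forwarded to `jax.jit`, so plain and `static_argnames` decorators keep the same shape.

```python
import functools
import os
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp

# Hypothetical switch (NOT an existing JaxSim or JAX option). It is read once,
# at import time, so it affects functions decorated after this module loads.
ENABLE_API_JIT = os.environ.get("JAXSIM_HYPOTHETICAL_JIT", "1") != "0"


def maybe_jit(fn: Optional[Callable] = None, **jit_kwargs: Any):
    """Apply jax.jit only if ENABLE_API_JIT is True (hypothetical helper).

    Usable both as `@maybe_jit` and as `@maybe_jit(static_argnames=...)`.
    """
    if fn is None:
        # Called with keyword arguments only: return the actual decorator.
        return functools.partial(maybe_jit, **jit_kwargs)
    if not ENABLE_API_JIT:
        # Leave the function un-jitted; the caller decides where to jit.
        return fn
    return jax.jit(fn, **jit_kwargs)


@maybe_jit
def multiply(x, y):
    return x * y


@maybe_jit(static_argnames=("n",))
def power(x, n: int):
    result = jnp.ones_like(x)
    for _ in range(n):  # Python loop: `n` must be static when jitted
        result = result * x
    return result


print(multiply(jnp.array(4), jnp.array(3)))
print(power(jnp.array(2.0), n=3))
```

A global flag keeps every API signature unchanged; the trade-off is that it is evaluated once, when the decorated functions are defined, rather than per call.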

FYI @traversaro @CarlottaSartore @diegoferigo @xela-95

traversaro commented 3 days ago

Just to understand: did you check whether the same problem happens if we use decorators instead of calling the `jax.jit` function directly? I would expect the decorator and the function to behave the same, but, you know, the devil is in the details.

flferretti commented 3 days ago

That's a good observation! However, the behavior remains the same:

- Without the `jax.jit` decorator:

```python
import jax

def fn(x: int, y: int):

    def multiply(x: int, y: int):
        return x * y

    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)
```

```
Finished tracing + transforming multiply for pjit in 0.00031876564025878906 sec
Finished tracing + transforming fn for pjit in 0.0007829666137695312 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002622365951538086 sec
Finished XLA compilation of jit(fn) in 0.04105210304260254 sec
```


- With `jax.jit` decorator:

```python
import jax

def fn(x: int, y: int):

    @jax.jit
    def multiply(x: int, y: int):
        return x * y

    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)
```

```
Finished tracing + transforming multiply for pjit in 0.0002593994140625 sec
Finished tracing + transforming multiply for pjit in 0.0005724430084228516 sec
Finished tracing + transforming fn for pjit in 0.0009195804595947266 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002263784408569336 sec
Finished XLA compilation of jit(fn) in 0.04090547561645508 sec
```

As you can see, a tracing step for `multiply` now appears twice in the log, while XLA compilation of `jit(fn)` still happens only once.
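The extra tracing step can also be seen without reading logs, by inspecting the jaxpr produced by `jax.make_jaxpr`: decorating the inner `multiply` adds one nested jit call. The helper `count_jit_eqns` below is hypothetical, and it relies on assumptions about JAX internals that may change between releases: the jit primitive is named `pjit` in many JAX versions (newer ones may call it `jit`, so both are accepted), and the callee's jaxpr is stored in the equation's `jaxpr` parameter.

```python
import jax


def count_jit_eqns(jaxpr) -> int:
    """Count jit calls in a jaxpr, including nested ones (hypothetical helper)."""
    total = 0
    for eqn in jaxpr.eqns:
        # Many JAX releases name this primitive "pjit"; newer ones may use "jit".
        if eqn.primitive.name in ("pjit", "jit"):
            # Assumption about JAX internals: the callee's ClosedJaxpr is
            # stored in the equation's "jaxpr" parameter.
            total += 1 + count_jit_eqns(eqn.params["jaxpr"].jaxpr)
    return total


def fn_plain(x, y):
    def multiply(x, y):
        return x * y

    return multiply(x, y)


def fn_decorated(x, y):
    @jax.jit
    def multiply(x, y):
        return x * y

    return multiply(x, y)


plain = count_jit_eqns(jax.make_jaxpr(fn_plain)(4, 3).jaxpr)
decorated = count_jit_eqns(jax.make_jaxpr(fn_decorated)(4, 3).jaxpr)
# The decorated version contains one extra nested jit call.
print(plain, decorated)
```

This only changes what gets traced; as the logs above show, XLA compilation still runs once for the outer `jit(fn)`.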

flferretti commented 3 days ago

Check out the branch https://github.com/ami-iit/jaxsim/compare/main...remove_jit for testing

diegoferigo commented 2 days ago

Note that in all your examples you are applying the jit transformation to temporary objects. Try decorating the outer function instead (maybe it automatically ends up in the cache after the first run, but it's worth checking).
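A small sketch of the check suggested above. It counts traces with a Python side effect, which runs only while JAX traces a function, not when the compiled code executes. A module-level function decorated once is traced a single time for a given input signature, while `jax.jit` applied to a freshly created function object (as happens with an inner function defined inside another function) retraces on every call. The names `fn_decorated` and `make_fn` are illustrative only.

```python
import jax

trace_counts = {"decorated": 0, "temporary": 0}


@jax.jit
def fn_decorated(x, y):
    # Python side effect: executed only while JAX traces this function.
    trace_counts["decorated"] += 1
    return x * y


def make_fn():
    # Returns a brand-new function object on every call.
    def fn(x, y):
        trace_counts["temporary"] += 1
        return x * y

    return fn


for _ in range(3):
    fn_decorated(4, 3)        # same jitted object: traced once, then cached
    jax.jit(make_fn())(4, 3)  # new function object each time: traced again

print(trace_counts)  # {'decorated': 1, 'temporary': 3}
```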

diegoferigo commented 2 days ago

Generally speaking, we have jit decorators on all APIs for two reasons:

1. people who use the project interactively do not need to remember to apply the jit transformation (or understand how to do it);
2. it is the simplest way to define static function arguments, which cannot be left to the user because it would introduce an additional burden.

Removing only the plain `@jax.jit` decorators, while keeping those that use `functools.partial` to declare static arguments, may introduce asymmetries in the APIs.