flferretti opened this issue 3 days ago.
Just to understand: did you try whether the same problem happens when using the decorator instead of calling the `jax.jit` function directly? I would expect the decorator and the function to behave the same, but you know the devil is in the details.
That's a good observation! However, the behavior remains the same:

- Without `jax.jit` decorator:
```python
import jax

def fn(x: int, y: int):
    def multiply(x: int, y: int):
        return x * y
    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)
```

```
Finished tracing + transforming multiply for pjit in 0.00031876564025878906 sec
Finished tracing + transforming fn for pjit in 0.0007829666137695312 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002622365951538086 sec
Finished XLA compilation of jit(fn) in 0.04105210304260254 sec
```
- With `jax.jit` decorator:
```python
import jax

def fn(x: int, y: int):
    @jax.jit
    def multiply(x: int, y: int):
        return x * y
    return multiply(x, y)

with jax.log_compiles():
    jax.jit(fn)(4, 3)
```

```
Finished tracing + transforming multiply for pjit in 0.0002593994140625 sec
Finished tracing + transforming multiply for pjit in 0.0005724430084228516 sec
Finished tracing + transforming fn for pjit in 0.0009195804595947266 sec
Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: [UnspecifiedValue, UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(fn) in 0.002263784408569336 sec
Finished XLA compilation of jit(fn) in 0.04090547561645508 sec
```
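As you can see, the function `multiply` is traced and transformed twice (there are two `Finished tracing + transforming multiply` entries), while `fn` goes through XLA compilation only once.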
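The same effect can be measured without parsing the log output. The sketch below is an illustration, not code from JaxSim: the `counts` dictionary is a hypothetical helper, and it relies only on the fact that Python side effects run while JAX traces a function, not when a cached trace is reused. The exact number of `multiply` traces may vary between JAX versions, so the sketch only checks that a second call with the same argument types adds no new traces.

```python
import jax

# Hypothetical trace counters, used only for this sketch.
counts = {"fn": 0, "multiply": 0}

def fn(x: int, y: int):
    counts["fn"] += 1  # incremented only while `fn` is being traced

    @jax.jit  # a new jitted function object is created on every trace of `fn`
    def multiply(x: int, y: int):
        counts["multiply"] += 1  # incremented only while `multiply` is traced
        return x * y

    return multiply(x, y)

jitted_fn = jax.jit(fn)
jitted_fn(4, 3)
after_first_call = dict(counts)

jitted_fn(5, 6)  # same shapes and dtypes: the cached computation is reused
after_second_call = dict(counts)

print(after_first_call, after_second_call)
```

Counting traces this way makes it easy to compare variants of the API (nested `jax.jit`, plain inner functions, hoisted helpers) on the version of JAX actually installed.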
Check out the `remove_jit` branch for testing: https://github.com/ami-iit/jaxsim/compare/main...remove_jit
Note that in all your examples you are applying the jit transformation to temporary objects. Try decorating the outer function instead (it may be picked up from the cache automatically after the first run, but it's worth checking).
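One way to do that check is to count traces for both styles. This is a minimal sketch under stated assumptions: `decorated` and `plain` are hypothetical stand-ins for an API function, and the counters rely on Python side effects running only during tracing. A function decorated once is traced a single time for repeated calls with the same argument types; whether wrapping with a temporary `jax.jit(...)` on every call also hits the cache depends on how the installed JAX version keys its cache, so that count is printed rather than assumed.

```python
import jax

# Hypothetical trace counters for the two styles being compared.
counts = {"decorated": 0, "temporary": 0}

@jax.jit  # jit applied once, when the outer function is defined
def decorated(x, y):
    counts["decorated"] += 1  # side effect: runs only while tracing
    return x * y

def plain(x, y):
    counts["temporary"] += 1  # side effect: runs only while tracing
    return x * y

for x, y in [(4, 3), (5, 6), (7, 8)]:
    decorated(x, y)
    jax.jit(plain)(x, y)  # a new, temporary jit wrapper on every call

# `decorated` is traced once; the `temporary` count shows whether the
# temporary wrappers were served from the cache on this JAX version.
print(counts)
```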
Generally speaking, we need jit decorators on all APIs because 1) people who use the project interactively do not need to remember to apply the jit transformation (or understand how to do it); 2) it's the simplest way to define static function arguments, which cannot be left to the user because it would introduce an additional burden.
Removing only the decorators that do not use `partial` may introduce asymmetries in the APIs.
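For readers less familiar with the two decorator styles, here is a minimal sketch; `scale` and `power` are hypothetical examples, not JaxSim APIs. A plain `@jax.jit` traces every argument, while `functools.partial(jax.jit, static_argnames=...)` marks some arguments as static, so their concrete Python values are available at trace time (and become part of the cache key). Dropping the decorator from one style but not the other is what would make the APIs asymmetric.

```python
import functools

import jax
import jax.numpy as jnp

# Plain decorator: all arguments are traced (only shape/dtype are known).
@jax.jit
def scale(x: jax.Array, factor: float) -> jax.Array:
    return x * factor

# Decorator built with functools.partial: `n` is static, so it is a regular
# Python int at trace time and a different `n` triggers a new trace.
@functools.partial(jax.jit, static_argnames=("n",))
def power(x: jax.Array, n: int) -> jax.Array:
    result = jnp.ones_like(x)
    for _ in range(n):  # Python loop unrolled at trace time; needs a static `n`
        result = result * x
    return result

print(scale(jnp.float32(2.0), 1.5))
print(power(jnp.float32(2.0), n=3))
```

Without `static_argnames`, `range(n)` would receive a tracer and fail, which is why static arguments are hard to leave to the user.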
Currently, we decorate every method and function inside `jaxsim.api` with `jax.jit`. However, this introduces overhead, as the inner functions get compiled multiple times. While this can be convenient when using JaxSim as a multibody dynamics library, it can lead to unexpected results or additional overhead that could be removed.
FYI @traversaro @CarlottaSartore @diegoferigo @xela-95