ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Impossible to run a model with zero DoFs with JAX_DISABLE_JIT set to True #191

Open xela-95 opened 4 days ago

xela-95 commented 4 days ago

Related issue on JAX: https://github.com/google/jax/issues/4668

The error is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /home/acroci/repos/component_alpha/rigid_contacts_analytical.py:11
      7 integration_time = 0.001
      9 representation = jaxsim.VelRepr.Mixed
---> 11 data = js.data.JaxSimModelData.build(
     12     model=model,
     13     velocity_representation=representation,  # standard_gravity=7.0
     14 )
     15 # integrator = integrators.fixed_step.RungeKutta4SO3.build(
     16 # integrator = integrators.fixed_step.ForwardEuler.build(
     17 integrator = integrators.fixed_step.ForwardEulerSO3.build(
     18     dynamics=js.ode.wrap_system_dynamics_for_integration(
     19         model=model,
   (...)
     25     ),
     26 )

File ~/repos/jaxsim/src/jaxsim/api/data.py:186, in JaxSimModelData.build(model, base_position, base_quaternion, joint_positions, base_linear_velocity, base_angular_velocity, joint_velocities, standard_gravity, contact, contacts_params, velocity_representation, time)
    176 time_ns = (
    177     jnp.array(time * 1e9, dtype=jnp.uint64)
    178     if time is not None
    179     else jnp.array(0, dtype=jnp.uint64)
    180 )
    182 if isinstance(model.contact_model, SoftContacts):
    183     contacts_params = (
    184         contacts_params
    185         if contacts_params is not None
--> 186         else js.contact.estimate_good_soft_contacts_parameters(
    187             model=model, standard_gravity=standard_gravity
    188         )
    189     )
    190 else:
    191     contacts_params = model.contact_model.parameters

File ~/repos/jaxsim/src/jaxsim/api/contact.py:270, in estimate_good_soft_contacts_parameters(model, standard_gravity, static_friction_coefficient, number_of_active_collidable_points_steady_state, damping_ratio, max_penetration)
    263         return 2 * (W_pz_CoM - W_pz_C.min())
    265     return 2 * W_pz_CoM
    267 max_δ = (
    268     max_penetration
    269     if max_penetration is not None
--> 270     else 0.005 * estimate_model_height(model=model)
    271 )
    273 nc = number_of_active_collidable_points_steady_state
    275 sc_parameters = SoftContactsParams.build_default_from_jaxsim_model(
    276     model=model,
    277     standard_gravity=standard_gravity,
   (...)
    281     damping_ratio=damping_ratio,
    282 )

File ~/repos/jaxsim/src/jaxsim/api/contact.py:259, in estimate_good_soft_contacts_parameters.<locals>.estimate_model_height(model)
    252 """"""
    254 zero_data = js.data.JaxSimModelData.build(
    255     model=model,
    256     contacts_params=SoftContactsParams(),
    257 )
--> 259 W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
    261 if model.floating_base():
    262     W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]

File ~/repos/jaxsim/src/jaxsim/api/com.py:29, in com_position(model, data)
     16 """
     17 Compute the position of the center of mass of the model.
     18
   (...)
     24     The position of the center of mass of the model w.r.t. the world frame.
     25 """
     27 m = js.model.total_mass(model=model)
---> 29 W_H_L = js.model.forward_kinematics(model=model, data=data)
     30 W_H_B = data.base_transform()
     31 B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()

File ~/repos/jaxsim/src/jaxsim/api/model.py:441, in forward_kinematics(model, data)
    427 @jax.jit
    428 def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    429     """
    430     Compute the SE(3) transforms from the world frame to the frames of all links.
    431
   (...)
    438         The first axis is the link index.
    439     """
--> 441     W_H_LL = jaxsim.rbda.forward_kinematics_model(
    442         model=model,
    443         base_position=data.base_position(),
    444         base_quaternion=data.base_orientation(dcm=False),
    445         joint_positions=data.joint_positions(model=model),
    446     )
    448     return jnp.atleast_3d(W_H_LL).astype(float)

File ~/repos/jaxsim/src/jaxsim/rbda/forward_kinematics.py:78, in forward_kinematics_model(model, base_position, base_quaternion, joint_positions)
     74     W_X_i = W_X_i.at[i].set(W_X_i_i)
     76     return (W_X_i,), None
---> 78 (W_X_i,), _ = jax.lax.scan(
     79     f=propagate_kinematics,
     80     init=propagate_kinematics_carry,
     81     xs=jnp.arange(start=1, stop=model.number_of_links()),
     82 )
     84 return jax.vmap(Adjoint.to_transform)(W_X_i)

    [... skipping hidden 1 frame]

File ~/mambaforge/envs/jaxsim/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:231, in scan(f, init, xs, length, reverse, unroll, _split_transpose)
    229 if config.disable_jit.value:
    230   if length == 0:
--> 231     raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
    232   carry = init
    233   ys = []

ValueError: zero-length scan is not supported in disable_jit() mode because the output type is unknown.
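The failure can be reproduced without jaxsim. The following is a minimal sketch, assuming only `jax` is installed: the hypothetical `sum_over_links` function mimics the `jax.lax.scan` over `jnp.arange(start=1, stop=model.number_of_links())` in `forward_kinematics_model`, so a single-link input produces a zero-length scan. It is an illustration, not jaxsim code.

```python
import jax
import jax.numpy as jnp


def sum_over_links(values: jax.Array) -> jax.Array:
    """Hypothetical stand-in for an RBDA loop that scans over links 1..n-1."""

    def step(carry, i):
        # Accumulate the value of link i into the carry; no per-step output.
        return carry + values[i], None

    total, _ = jax.lax.scan(
        f=step,
        init=values[0],
        xs=jnp.arange(start=1, stop=values.shape[0]),
    )
    return total


if __name__ == "__main__":
    single_link = jnp.array([1.0])  # one link only, i.e. a model with zero DoFs

    # JIT-compiled: the zero-length scan is accepted and returns the initial carry.
    print(jax.jit(sum_over_links)(single_link))

    # JIT disabled: the same call raises the ValueError shown in the traceback.
    with jax.disable_jit():
        try:
            sum_over_links(single_link)
        except ValueError as err:
            print(f"ValueError: {err}")
```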
diegoferigo commented 4 days ago

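I already encountered this problem in the past for similar reasons. Refer to:

* [Add new test suite of functional APIs #106](https://github.com/ami-iit/jaxsim/pull/106)

* [945f04b](https://github.com/ami-iit/jaxsim/commit/945f04b683c3519772ad4ec7bb916bacd4400a3f)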

On that occasion, I fixed only the RBDAs that I needed. In your case, other ones fail for similar reasons. Can you try applying something similar to the following, so that the scan call is skipped?

https://github.com/ami-iit/jaxsim/blob/7340d43172d5ec528f3f7547d532ffeb9770355d/src/jaxsim/rbda/aba.py#L163-L171
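The linked `aba.py` lines are not reproduced in this thread. As a sketch of the general idea only, the hypothetical `propagate_along_chain` function below guards its scan with a plain Python `if` on the number of links. That number is read from the array shape, so it is known at trace time; with a single link the scan is skipped entirely and only the base value is returned.

```python
import jax
import jax.numpy as jnp


def propagate_along_chain(link_offsets: jax.Array) -> jax.Array:
    """Hypothetical example: cumulative position of each link in a serial chain.

    link_offsets has shape (n_links,); link 0 is the base.
    """
    n_links = link_offsets.shape[0]  # static: a Python int known at trace time
    positions = jnp.zeros_like(link_offsets).at[0].set(link_offsets[0])

    def step(carry, i):
        (positions,) = carry
        # Each child link is offset from its parent (link i - 1 in a serial chain).
        positions = positions.at[i].set(positions[i - 1] + link_offsets[i])
        return (positions,), None

    # Static guard: run the scan only if there is at least one non-base link.
    if n_links > 1:
        (positions,), _ = jax.lax.scan(
            f=step,
            init=(positions,),
            xs=jnp.arange(1, n_links),
        )

    return positions


if __name__ == "__main__":
    with jax.disable_jit():
        print(propagate_along_chain(jnp.array([1.0])))  # single link: no scan runs
        print(propagate_along_chain(jnp.array([1.0, 2.0, 3.0])))
```

Because the condition depends only on a shape, the same function also works unchanged under `jax.jit`.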

flferretti commented 4 days ago

This can be added to https://github.com/ami-iit/jaxsim/pull/186

diegoferigo commented 4 days ago

> This can be added to #186

Probably it's time to merge that PR. We're already a bit beyond what I would call "minor changes", as often happens :smile:

traversaro commented 4 days ago

> I already encountered this problem in the past for similar reasons. Refer to:
>
> * [Add new test suite of functional APIs #106](https://github.com/ami-iit/jaxsim/pull/106)
>
> * [945f04b](https://github.com/ami-iit/jaxsim/commit/945f04b683c3519772ad4ec7bb916bacd4400a3f)
>
> On that occasion, I fixed only the RBDAs that I needed. In your case, other ones fail for similar reasons. Can you try applying something similar to the following, so that the scan call is skipped?
>
> https://github.com/ami-iit/jaxsim/blob/7340d43172d5ec528f3f7547d532ffeb9770355d/src/jaxsim/rbda/aba.py#L163-L171

Cool, thanks! We had the intuition that some workaround was necessary, but we were a bit clueless. @xela-95, maybe you can open a PR yourself with the fix proposed by @diegoferigo?

xela-95 commented 4 days ago

> Cool, thanks! We had the intuition that some workaround was necessary, but we were a bit clueless. @xela-95, maybe you can open a PR yourself with the fix proposed by @diegoferigo?

Sure, I'll try to see if this fixes the issue and then open a PR :)

traversaro commented 4 days ago

xref other jax issues:

The fix suggested by @diegoferigo is useful here, and it may also help users who land on the related JAX issues from search engines.

diegoferigo commented 4 days ago

It's worth noting that (if I'm not mistaken) the fix works in our case only because the if condition operates on a static element (reached by following model, kin_dyn_parameters, link_names). I fear that it won't work if the condition cannot be evaluated statically. In that case, using jax.lax.cond might be necessary.
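The distinction between a static and a non-static condition can be illustrated with three hypothetical functions. This is a generic JAX sketch, not jaxsim code: a Python `if` on an array shape is resolved while tracing, so it works under `jax.jit`; a Python `if` on an array value fails under `jax.jit` because the value is an abstract tracer at that point; `jax.lax.cond` accepts a traced predicate.

```python
import jax
import jax.numpy as jnp


def scale_static(x: jax.Array) -> jax.Array:
    # The shape is a plain Python int at trace time: a Python `if` is fine.
    if x.shape[0] > 1:
        return x * 2.0
    return x


def scale_python_if_on_value(x: jax.Array) -> jax.Array:
    # The value is a tracer under jax.jit: converting it to bool raises an error.
    if x.sum() > 0:
        return x * 2.0
    return x


def scale_lax_cond(x: jax.Array) -> jax.Array:
    # jax.lax.cond takes the (possibly traced) predicate and both branches.
    return jax.lax.cond(x.sum() > 0, lambda v: v * 2.0, lambda v: v, x)


if __name__ == "__main__":
    x = jnp.array([1.0, 2.0])
    print(jax.jit(scale_static)(x))
    print(jax.jit(scale_lax_cond)(x))
    try:
        jax.jit(scale_python_if_on_value)(x)
    except jax.errors.ConcretizationTypeError as err:
        print(type(err).__name__)
```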

traversaro commented 4 days ago

Most cases we saw were indeed due to model.number_of_links() > 1 (and probably went unnoticed as it is not so common to integrate a rigid body without joints).

diegoferigo commented 3 days ago

> Most cases we saw were indeed due to model.number_of_links() > 1 (and probably went unnoticed as it is not so common to integrate a rigid body without joints).

We actually do support that, and single-body models are also part of our test suite (together with a fixed-base and a floating-base model). This went unnoticed because JIT is automatically enabled in tests, and JIT-compiled jax.lax.scan calls do not complain if there is no actual iteration. They complain only when called with JAX_DISABLE_JIT set or inside a jax.disable_jit() context.
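One way to catch this class of bug in a test suite is to evaluate each function both JIT-compiled and inside a `jax.disable_jit()` context, then compare the results. The sketch below assumes a hypothetical helper `run_jit_and_eager` and a hypothetical RBDA-like function `total_over_links`; neither is part of jaxsim. The function guards its scan on the static number of links, so both runs succeed for a single-link input; removing the guard would make only the eager run fail.

```python
import jax
import jax.numpy as jnp


def run_jit_and_eager(fn, *args):
    """Hypothetical test helper: evaluate fn JIT-compiled and with JIT disabled."""
    jitted = jax.jit(fn)(*args)
    with jax.disable_jit():
        eager = fn(*args)
    return jitted, eager


def total_over_links(values: jax.Array) -> jax.Array:
    """Hypothetical RBDA-like loop: sum of link values, scanning over links 1..n-1."""
    if values.shape[0] == 1:
        # Static guard: nothing to scan for a single-link model.
        return values[0]
    total, _ = jax.lax.scan(lambda carry, v: (carry + v, None), values[0], values[1:])
    return total


if __name__ == "__main__":
    for values in (jnp.array([1.0]), jnp.array([1.0, 2.0, 3.0])):
        jitted, eager = run_jit_and_eager(total_over_links, values)
        print(jitted, eager)
```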