waymo-research / waymax

A JAX-based simulator for autonomous driving research.

Discrete Wrapper incompatible with jax.lax.scan for PlanningAgent dynamics/environment #37

Open clemgris opened 7 months ago

clemgris commented 7 months ago

I am currently working with the InvertibleBicycleModel together with the PlanningAgent dynamics and environment in my project. I've encountered a jax.errors.ConcretizationTypeError when I apply the Discrete wrapper (DiscreteActionSpaceWrapper) and update the simulator state inside a jax.lax.scan function. No error is raised when the simulator state is updated in a regular Python loop.

The error only occurs when the Discrete wrapper is applied; removing the wrapper makes it go away.

Reproduction Steps:

Initialisation of the dataset.

import jax
import jax.numpy as jnp
import dataclasses

from waymax import agents
from waymax import config as _config
from waymax import dataloader
from waymax import dynamics
from waymax import env as _env

WOD_1_1_0_VALIDATION = _config.DatasetConfig(
    path='gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/tf_example/validation/validation_tfexample.tfrecord@150',
    max_num_rg_points=20000,
    data_format=_config.DataFormat.TFRECORD,
)

config = dataclasses.replace(WOD_1_1_0_VALIDATION, max_num_objects=32)
data_iter = dataloader.simulator_state_generator(config=config)
scenario = next(data_iter)

Initialisation of the environment and dynamics (with the Discrete wrapper applied), as well as an expert agent.

env_config = _config.EnvironmentConfig(
    controlled_object=_config.ObjectType.SDC,
    max_num_objects=32
)

wrapped_dynamics_model = dynamics.InvertibleBicycleModel()
dynamics_model = _env.PlanningAgentDynamics(wrapped_dynamics_model)

action_space_dim = dynamics_model.action_spec().shape
dynamics_model = dynamics.discretizer.DiscreteActionSpaceWrapper(
    dynamics_model=dynamics_model,
    bins=128 * jnp.array(action_space_dim),
)

env = _env.PlanningAgentEnvironment(
    dynamics_model=wrapped_dynamics_model,
    config=env_config,
)

expert_agent = agents.create_expert_actor(dynamics_model)

Updating the simulator state from the expert action in a regular Python loop does not raise any error:

current_state = env.reset(scenario)

for _ in range(10):
    expert_action = expert_agent.select_action(state=current_state, params=None, rng=None, actor_state=None).action
    current_state = env.step(current_state, expert_action)

The same update performed inside jax.lax.scan raises jax.errors.ConcretizationTypeError:

current_state = env.reset(scenario)

def _env_step(current_state, unused):

    expert_action = expert_agent.select_action(state=current_state, params=None, rng=None, actor_state=None).action
    current_state = env.step(current_state, expert_action)

    return current_state, None

current_state, _ = jax.lax.scan(f=_env_step, init=current_state, xs=None, length=10)

The full traceback is:

Traceback (most recent call last):
  File "/home/chetouani/Documents/INTERNSHIP_Clemence_Oxford/debug_imitation_gap_WOMD/issues/discrete_wrapper.py", line 63, in <module>
    current_state, _ = jax.lax.scan(f=_env_step, init=current_state, xs=None, length=10)
  File "/home/chetouani/Documents/INTERNSHIP_Clemence_Oxford/debug_imitation_gap_WOMD/issues/discrete_wrapper.py", line 57, in _env_step
    expert_action = expert_agent.select_action(state=current_state, params=None, rng=None, actor_state=None).action
  File "/home/chetouani/miniforge-pypy3/envs/womd/lib/python3.10/site-packages/waymax/agents/actor_core.py", line 163, in select_action
    return select_action(params, state, actor_state, rng)
  File "/home/chetouani/miniforge-pypy3/envs/womd/lib/python3.10/site-packages/waymax/agents/expert.py", line 96, in select_action
    logged_action = infer_expert_action(state, dynamics_model)
  File "/home/chetouani/miniforge-pypy3/envs/womd/lib/python3.10/site-packages/waymax/agents/expert.py", line 58, in infer_expert_action
    return dynamics_model.inverse(
  File "/home/chetouani/miniforge-pypy3/envs/womd/lib/python3.10/site-packages/waymax/dynamics/discretizer.py", line 205, in inverse
    data=self._discretizer.discretize(action_cont.data),
  File "/home/chetouani/miniforge-pypy3/envs/womd/lib/python3.10/site-packages/waymax/dynamics/discretizer.py", line 76, in discretize
    indices_1d = jnp.ravel_multi_index(  # pytype: disable=wrong-arg-types  # jnp-type
  File "/home/chetouani/miniforge-pypy3/envs/womd/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 804, in ravel_multi_index
    dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
  File "/home/chetouani/miniforge-pypy3/envs/womd/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 804, in <genexpr>
    dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
in `dims` argument of ravel_multi_index().
The error occurred while tracing the function _env_step at /home/chetouani/Documents/INTERNSHIP_Clemence_Oxford/debug_imitation_gap_WOMD/issues/discrete_wrapper.py:55 for scan. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = add b c
    from line /home/chetouani/Documents/INTERNSHIP_Clemence_Oxford/debug_imitation_gap_WOMD/issues/discrete_wrapper.py:57 (_env_step)

  operation a:f32[1] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line /home/chetouani/Documents/INTERNSHIP_Clemence_Oxford/debug_imitation_gap_WOMD/issues/discrete_wrapper.py:57 (_env_step)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /home/chetouani/Documents/INTERNSHIP_Clemence_Oxford/debug_imitation_gap_WOMD/issues/discrete_wrapper.py:57 (_env_step)

  operation a:i32[1] = add b c
    from line /home/chetouani/Documents/INTERNSHIP_Clemence_Oxford/debug_imitation_gap_WOMD/issues/discrete_wrapper.py:57 (_env_step)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
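The traceback ends in the dims argument of jnp.ravel_multi_index and lists add operations as the reason the value became a tracer. The sketch below isolates that pattern in plain JAX, without Waymax. The helper flatten_action and the small bins values are hypothetical stand-ins, assuming the wrapper does arithmetic on bins before passing it as dims; this is not Waymax's actual discretizer code. Called eagerly, the jnp result is concrete and the call succeeds; inside jax.lax.scan, the same arithmetic is staged into the traced body and JAX raises ConcretizationTypeError, which matches the loop-versus-scan behaviour described in this issue.

```python
import jax
import jax.numpy as jnp


def flatten_action(indices, bins):
    # Hypothetical stand-in for the wrapper's discretize step: arithmetic
    # on `bins` feeds the `dims` argument of ravel_multi_index, which JAX
    # requires to be concrete integers.
    dims = bins + 1
    return jnp.ravel_multi_index(tuple(indices), dims, mode="clip")


bins = jnp.array([3, 4])  # jnp bins, as in the reproduction in this issue
indices = jnp.array([1, 2])

# Eager call: `bins + 1` is a concrete array, so this succeeds.
eager_result = flatten_action(indices, bins)


# Traced call: inside scan, `bins + 1` is staged out and becomes a tracer.
def step(carry, _):
    return carry, flatten_action(indices + carry, bins)


try:
    jax.lax.scan(step, jnp.int32(0), xs=None, length=3)
    scan_failed = False
except jax.errors.ConcretizationTypeError:
    scan_failed = True

print(int(eager_result), scan_failed)  # → 7 True
```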
hungdche commented 4 months ago

Hi. Have you managed to resolve the issue?

clemgris commented 4 months ago

Hi, no, the issue is still unresolved.

vuoristo commented 2 months ago

I think the problem here is that the bins argument of the wrapper ends up as the dims argument of jnp.ravel_multi_index in the discretize function, and that argument must be a static (concrete) value. Inside jax.lax.scan, jnp operations on bins produce tracers, so this check fails.

So instead of a jnp.array, you could, for example, use a regular np.array when you create the DiscreteActionSpaceWrapper:

import numpy as np

dynamics_model = dynamics.discretizer.DiscreteActionSpaceWrapper(
    dynamics_model=dynamics_model,
    bins=128 * np.array(action_space_dim)
)
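A minimal plain-JAX sketch of why the NumPy version works, assuming the wrapper derives the dims argument of jnp.ravel_multi_index from bins with arithmetic, as the add operations in the traceback suggest. The bins and indices values are small illustrative numbers, not Waymax's. Because bins is a NumPy array, bins + 1 is evaluated by NumPy while the scan body is traced and stays concrete, so only the indices are traced.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative NumPy bins (the issue uses 128 * np.array(action_space_dim)).
bins = np.array([3, 4])
indices = jnp.array([1, 2])


def step(carry, _):
    # NumPy arithmetic: `dims` stays concrete even while the body is traced.
    dims = bins + 1
    # The indices depend on the traced carry, which ravel_multi_index allows
    # as long as mode is "clip" or "wrap".
    flat = jnp.ravel_multi_index(tuple(indices + carry), dims, mode="clip")
    return carry, flat


_, flat = jax.lax.scan(step, jnp.int32(0), xs=None, length=3)
print(flat)  # → [7 7 7]
```

With dims = (4, 5) in row-major order, the index (1, 2) flattens to 1 * 5 + 2 = 7 at every step.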