coax-dev / coax

Modular framework for Reinforcement Learning in python
https://coax.readthedocs.io
MIT License

Assertion assert_equal_shape failed for MultiDiscrete action space #21

Closed: xiangyuy closed this issue 2 years ago

xiangyuy commented 2 years ago

First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently doing a baseline study with my customized gym environment and VanillaPG, but I ran into the bug shown below and could not figure it out. My understanding is that it is complaining that the shape of log_pi should not be (4,). But I do have a MultiDiscrete action space, and its corresponding log_pi should be something like (4,) or (1, 4). I also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.

So my questions are:

  1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
  2. Did I declare my policy function properly?
  3. Any general ideas on how to debug JAX functions? (A small sketch follows below.)

I would appreciate any feedback. Thank you!
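
On question 3, one generic trick (a sketch only, not code from this thread): array shapes are static even while JAX is tracing, so printing .shape inside a jitted or grad-transformed function is always meaningful, and jax.disable_jit() lets the same code run eagerly so intermediate values are ordinary arrays you can inspect:

import jax
import jax.numpy as jnp

@jax.jit
def objective(log_pi, adv):
    # shapes are concrete even during tracing, so this print is informative
    print("log_pi:", log_pi.shape, "adv:", adv.shape)
    return jnp.mean(adv * log_pi)

objective(jnp.zeros((4,)), jnp.zeros((1,)))   # prints log_pi: (4,) adv: (1,)

# run eagerly: jit becomes a no-op, so values inside are ordinary arrays
with jax.disable_jit():
    objective(jnp.zeros((4,)), jnp.zeros((1,)))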

Error message

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Input In [25], in <cell line: 5>()
     13     transition_batch = tracer.pop()
     14     Gn = transition_batch.Rn
---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
     16     env.record_metrics(metrics)
     17 if done:

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
    127 def update(self, transition_batch, Adv):
    128     r"""
    129 
    130     Update the model parameters (weights) of the underlying function approximator.
   (...)
    147 
    148     """
--> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
    150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
    151         raise RuntimeError(f"found nan's in grads: {grads}")

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
    212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
    213     warnings.warn(
    214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
    215         "should be non-zero. Please sample actions with their propensities: "
    216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
    217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
--> 218 return self._grad_and_metrics_func(
    219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
    220     transition_batch, Adv)

File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
     58 def __call__(self, *args, **kwargs):
---> 59     return self._jitted_func(*args, **kwargs)

    [... skipping hidden 14 frame]

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
     77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
     78     grads_func = jax.grad(loss_func, has_aux=True)
     79     grads, (metrics, state_new) = \
---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
     82     # add some diagnostics of the gradients
     83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))

    [... skipping hidden 10 frame]

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
     45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
     46     objective, (dist_params, log_pi, state_new) = \
---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
     49     # flip sign to turn objective into loss
     50     loss = -objective

File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
     49 W = jnp.clip(transition_batch.W, 0.1, 10.)
     51 # some consistency checks
---> 52 chex.assert_equal_shape([W, Adv, log_pi])
     53 chex.assert_rank([W, Adv, log_pi], 1)
     54 objective = W * Adv * log_pi

File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
    195 else:
    196   try:
--> 197     host_assertion(*args, **kwargs)
    198   except jax.errors.ConcretizationTypeError as exc:
    199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
    200            "likely that it tried to access tensors' values during tracing. "
    201            "Make sure that you defined a jittable version of this Chex "
    202            "assertion.")

File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
    154     custom_message = custom_message.format(*custom_message_format_vars)
    155   error_msg = f"{error_msg} [{custom_message}]"
--> 157 raise exception_type(error_msg)

AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].

Example data

ExampleData(
  inputs=Inputs(
    args=ArgsType2(
      S={
        'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)},
      is_training=True)
    static_argnums=(
      1))
  output=(
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))

Policy function

def pi(S, is_training):
    module = CustomizedModule()
    # one {'logits': ...} dict per sub-action of the MultiDiscrete space
    res = tuple({"logits": item} for item in module(S["features"]))
    return res
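
(CustomizedModule is the user's own model and isn't shown in the thread. A hypothetical haiku stand-in that would produce the same output structure as the example data above, i.e. four heads of 10 logits each from a 1000-dim 'features' input, could look roughly like this; since coax wraps the policy function with haiku's transform, the module only needs a __call__:)

import haiku as hk
import jax.numpy as jnp

class CustomizedModule(hk.Module):
    """Hypothetical stand-in: a shared torso with four independent logit heads."""
    def __call__(self, features):
        h = jnp.tanh(hk.Linear(256)(features))
        # one head of 10 logits per sub-action, matching MultiDiscrete([10, 10, 10, 10])
        return tuple(hk.Linear(10)(h) for _ in range(4))
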
UweGensheimer commented 2 years ago

Hello, I was just wondering: have you tried flattening the output of the policy module and then reshaping it later, when it is passed to the step function of the environment? Just a simple check to pinpoint the error.
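
One way to read that suggestion (an illustrative sketch only, not code from the thread, assuming an action space like MultiDiscrete([10, 10, 10, 10])): collapse the space into a single Discrete action so the policy has one flat logits head, and unflatten the chosen action in a small wrapper before the real step():

import gym
import numpy as np

class FlattenMultiDiscrete(gym.ActionWrapper):
    """Expose a MultiDiscrete space as a single flat Discrete space."""
    def __init__(self, env):
        super().__init__(env)
        self.nvec = env.action_space.nvec                     # e.g. [10, 10, 10, 10]
        self.action_space = gym.spaces.Discrete(int(np.prod(self.nvec)))

    def action(self, a):
        # map the flat index back to one integer per sub-action dimension
        return np.array(np.unravel_index(a, self.nvec))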

xiangyuy commented 2 years ago

@UweGensheimer Thank you for your suggestion. Flattening the output is kind of complicated, so instead I tried setting the size of my MultiDiscrete to one (i.e., the policy function returns ({'logits': jnp.array})), and it seems to run without error. So using a MultiDiscrete action space does seem to be the problem here; not sure what to do next...

KristianHolsheimer commented 2 years ago

Thanks for reporting this!

I can confirm that this is a proper bug. It has to do with the way variates are pre/post-processed.

I'll have a closer look at it as soon as I have time, which is either tonight or tomorrow.

xiangyuy commented 2 years ago

Thank you so much for your speedy response! I see that #22 has passed all tests. Will it be merged into main soon, or do some more tests and review need to be done first?

Update: I edited _composite.py following 82bcd674c7b6ff2d8efd3f3d0e8e0b68184c4a84 and it now runs without error. Thank you again for your help!

KristianHolsheimer commented 2 years ago

Hi @xiangyuy, thanks for your patience. PR #22 has been merged. I bumped the version number because I also fixed a few other small things in the same PR.