RobertTLange / gymnax-blines

Baselines for gymnax 🤖
Apache License 2.0

UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1. #7

Open Miyamura80 opened 1 year ago

Miyamura80 commented 1 year ago

I trained my own PPO agent in the Breakout-MinAtar environment (with the standard config based on the yaml files provided in this repo). I have it saved in a .pkl now, and loaded the model and params with the load_neural_network function.

I wanted to benchmark the trained model (from the .pkl) using the def rollout(...) code snippet from the gymnax "getting-started" notebook. Below is the code:

...
obs, state = env.reset(rng_reset, env_params)

def policy_step(state_input, tmp):
    """lax.scan compatible step transition in jax env."""
    obs, state, policy_params, rng = state_input
    rng, rng_step, rng_net = jax.random.split(rng, 3)
    action = model.apply(policy_params, obs, rng_net)

    print("ACT: ", action)
    # action = action._sample_n(rng,1)[0]

    next_obs, next_state, reward, done, _ = env.step(
      rng_step, state, action, env_params
    )
    carry = [next_obs, next_state, policy_params, rng]
    return carry, [obs, action, reward, next_obs, done]

# Scan over episode step loop
_, scan_out = jax.lax.scan(
  policy_step,
  [obs, state, params, rng_episode],
  (),
  steps_in_episode
)

However, I get the following error:

UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-37-2f44bbc81671> in <module>
     38 # Scan over episode step loop
---> 39 _, scan_out = jax.lax.scan(
     40   policy_step,

=====================================================
28 frames
=====================================================

UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

IndexError                                Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/lax_numpy.py in _canonicalize_tuple_index(arr_ndim, idx, array_name)
   4257   len_without_none = _sum(1 for e in idx if e is not None and e is not Ellipsis)
   4258   if len_without_none > arr_ndim:
-> 4259     raise IndexError(
   4260         f"Too many indices for {array_name}: {len_without_none} "
   4261         f"non-None/Ellipsis indices for dim {arr_ndim}.")

IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

Okay, so some kind of indexing error. So I traced through the code:

The error specifically happens on the following line, in the policy_step() function:

next_obs, next_state, reward, done, _ = env.step(
      rng_step, state, action, env_params
)

Digging into the traceback, the error seems to come from step_env in the Breakout-MinAtar implementation:

in step_env(self, key, state, action, params)
     75         """Perform single timestep state transition."""
---> 76         a = self.action_set[action]
     77         state, new_x, new_y = step_agent(state, a)
     78         state, reward = step_ball_brick(state, new_x, new_y)
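
For context, this error message can be reproduced in isolation: NumPy-style indexing raises it whenever a 1-D array is indexed with a 2-tuple, because each tuple element is treated as an index into a separate dimension. A minimal sketch (the action_set values here are just illustrative, not the real Breakout-MinAtar ones):

```python
import jax.numpy as jnp

# Illustrative stand-in for the environment's 1-D action_set array
action_set = jnp.array([0, 1, 3])

# A scalar integer index is fine
a = action_set[1]

# But a tuple index means "one index per dimension", so a 2-tuple
# on a 1-D array raises:
# IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
try:
    a = action_set[0, 1]
except IndexError as e:
    print("IndexError:", e)
```

So if action arrives as a tuple rather than a scalar, self.action_set[action] fails in exactly this way.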

So I tried to debug what action was being passed by printing it in the code, like below:


print("ACT: ", action)
next_obs, next_state, reward, done, _ = env.step(
      rng_step, state, action, env_params
)

To which I get: ACT: (Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>, <tfp.distributions.Categorical 'Categorical' batch_shape=[] event_shape=[] dtype=int32>)

So I dug deeper, down to the actual neural network model that was implemented (checking that this was indeed the one being used):


class CategoricalSeparateMLP(nn.Module):
    """Split Actor-Critic Architecture for PPO."""

        ...

        logits = nn.Dense(
            self.num_output_units,
            bias_init=default_mlp_init(),
        )(x_a)
        # pi = distrax.Categorical(logits=logits)
        pi = tfp.distributions.Categorical(logits=logits)
        return v, pi

I should say here that I am fairly new to JAX, and may not fully understand how the Traced<ShapedArray> machinery works.

But as far as I understand, the

(Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>, <tfp.distributions.Categorical 'Categorical' batch_shape=[] event_shape=[] dtype=int32>)

means the output is a 2-tuple: a Traced<ShapedArray> (the value v, which hasn't been concretely computed yet, since JAX's functional tracing machinery evaluates it later) together with the pi output from the tfp.distributions.Categorical(logits=logits) in the definition of CategoricalSeparateMLP.
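
To illustrate the Traced<ShapedArray> part: inside a jit/scan trace, JAX replaces concrete arrays with abstract tracers, so print shows the shape/dtype rather than actual values. A tiny sketch (toy function, not from the repo):

```python
import jax
import jax.numpy as jnp

def f(x):
    # During tracing, x is a tracer, so this prints something like
    # Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(...)>
    print("inside trace:", x)
    return x * 2

y = jax.jit(f)(jnp.ones(1))  # the print fires once, at trace time
```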

And so, two questions (plus an optional one):

  1. Is there an easier way / are there better examples for benchmarking the performance of trained models?
  2. How could the above error be fixed? I tried, for example, action = action.sample() per tfp.distributions.Categorical(logits=logits) and passing that in. However, that somehow broke the model inference at the line action = model.apply(policy_params, obs, rng_net).
  3. (Optional) What was the reasoning for using tfp.distributions.Categorical(logits=logits)? I found it odd that pi = distrax.Categorical(logits=logits) was commented out, given that distrax is the JAX-native library. I also tried action = action._sample_n(rng, 1)[0] with distrax, but that seemed to fail in the same step_env in a similar way (i.e. with the same UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1).
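
For question 2, the pattern I would expect to work (hedged, since I haven't run it against this repo) is to unpack the (v, pi) tuple and only then sample a concrete integer action before calling env.step. A sketch in plain JAX; fake_apply below is a hypothetical stand-in for model.apply, and with tfp-on-JAX or distrax the sampling line would instead be action = pi.sample(seed=rng_net):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model.apply: CategoricalSeparateMLP returns
# (value, pi), so the raw apply output is a 2-tuple, not an action.
def fake_apply(params, obs, rng):
    logits = jnp.array([0.1, 0.5, 0.2, 0.2])  # illustrative policy logits
    v = jnp.array([0.0])                      # illustrative value estimate
    return v, logits

rng = jax.random.PRNGKey(0)
rng, rng_net = jax.random.split(rng)

# Unpack instead of passing the whole tuple on as the action...
v, logits = fake_apply(None, None, rng_net)

# ...then draw a concrete integer action that env.step can index with
action = jax.random.categorical(rng_net, logits)  # scalar int
```

With the real network this would read v, pi = model.apply(policy_params, obs, rng_net) followed by action = pi.sample(seed=rng_net), so that step_env's self.action_set[action] receives a scalar index.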