I trained my own PPO agent in the Breakout-MinAtar environment (with the standard config based on the yaml files provided in this repo). It is now saved as a .pkl, and I loaded the model and params with the load_neural_network function.
I wanted to benchmark how good my trained model (from the .pkl) is, based on the def rollout(...) function snippet from the gymnax "getting-started" notebook. Below is the code. However, I get the following error:
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-37-2f44bbc81671> in <module>
     38 # Scan over episode step loop
---> 39 _, scan_out = jax.lax.scan(
     40     policy_step,
[... 28 frames hidden ...]
UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
IndexError Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/lax_numpy.py in _canonicalize_tuple_index(arr_ndim, idx, array_name)
   4257     len_without_none = _sum(1 for e in idx if e is not None and e is not Ellipsis)
   4258     if len_without_none > arr_ndim:
-> 4259         raise IndexError(
   4260             f"Too many indices for {array_name}: {len_without_none} "
   4261             f"non-None/Ellipsis indices for dim {arr_ndim}.")
IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
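From what I can tell (and I may be wrong here), this error fires when an array is indexed with a tuple, since NumPy/JAX treat a tuple as one index per dimension. A minimal NumPy reproduction of the same failure:

```python
import numpy as np

row = np.zeros(5)   # 1-D array, e.g. one row of the Breakout grid
pair = (1, 2)       # a 2-tuple, like the (v, pi) pair the network returns
try:
    row[pair]       # two indices into a one-dimensional array
except IndexError as e:
    print(e)        # "too many indices for array ..."
```

So the question becomes: where does a tuple end up being used as an index inside step_env?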
Okay, cool, so some kind of indexing error. So I tried following the code:
The error specifically happens on the following line, in the policy_step() function. Digging into the error traces, it seems the error comes from step_env in the Breakout-MinAtar implementation. So I tried debugging what action was being passed by printing it in the code, and I get:
ACT: (Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>, <tfp.distributions.Categorical 'Categorical' batch_shape=[] event_shape=[] dtype=int32>)
So I dug deeper, down to the actual neural network model that was implemented (checking that this was indeed the one being used):
class CategoricalSeparateMLP(nn.Module):
"""Split Actor-Critic Architecture for PPO."""
...
logits = nn.Dense(
self.num_output_units,
bias_init=default_mlp_init(),
)(x_a)
# pi = distrax.Categorical(logits=logits)
pi = tfp.distributions.Categorical(logits=logits)
return v, pi
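Note that __call__ returns the pair (v, pi) rather than an action, so if the rollout does action = model.apply(...), the whole pair gets passed to env.step as the "action". Here is a small mock-up I used to convince myself (apply_fn is just a stand-in for model.apply, not the real Flax module):

```python
import numpy as np

def apply_fn(obs):
    """Stand-in for model.apply: like CategoricalSeparateMLP, it returns
    a (value, policy) pair, NOT a bare action."""
    value = 0.5
    logits = np.array([0.1, 0.2, 0.7])
    return value, logits

obs = np.zeros(4)
action = apply_fn(obs)           # BUG: 'action' is the whole (value, logits) pair
print(type(action))              # <class 'tuple'> -- this is what env.step received

v, logits = apply_fn(obs)        # fix: unpack the pair first...
action = int(np.argmax(logits))  # ...then derive the action from the policy
                                 # (greedy here, as a stand-in for sampling)
print(action)  # 2
```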
This is where I should say I am fairly new to JAX, and maybe don't understand in full detail how Traced<ShapedArray> works. But as far as I understand, the printed ACT value (Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>, <tfp.distributions.Categorical 'Categorical' batch_shape=[] event_shape=[] dtype=int32>) means that the first element is a Traced<ShapedArray> which hasn't fully finished computing (JAX's functional style defers the actual evaluation until later), and the second element is the pi output from the tfp.distributions.Categorical(logits=logits) in the definition of CategoricalSeparateMLP.
And so, two questions (plus an optional one):
1. Is there an easier way, or better examples, of benchmarking the performance of trained models?
2. How could the above error be fixed? I tried, for example, action = action.sample() (per tfp.distributions.Categorical(logits=logits)) and passing that in. However, that somehow messed with the model inference in the line action = model.apply(policy_params, obs, rng_net).
3. (Optional) What was the reasoning for using tfp.distributions.Categorical(logits=logits)? I found it odd that pi = distrax.Categorical(logits=logits) was commented out despite distrax being the native JAX-supported library. I also tried something similar with distrax, using action = action._sample_n(rng, 1)[0], but that seemed to fail at the same step_env in a similar way (i.e. with the same UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1).
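For context on question 1: the kind of benchmarking I had in mind is just averaging episode returns over a few rollouts, roughly like this sketch (StubEnv and policy here are placeholders for the gymnax environment and my loaded network, not anything from the repo):

```python
import random

class StubEnv:
    """Placeholder env: 10-step episodes, reward 1 per step."""
    def reset(self):
        self.t = 0
        return 0.0  # dummy observation
    def step(self, action):
        self.t += 1
        done = self.t >= 10
        return 0.0, 1.0, done  # obs, reward, done

def policy(obs):
    # Placeholder for sampling an action from the trained policy
    return random.randrange(3)

def benchmark(env, policy, num_episodes=5):
    """Average episode return of `policy` over `num_episodes` rollouts."""
    returns = []
    for _ in range(num_episodes):
        obs, done, total = env.reset(), False, 0.0
        while not done:
            obs, reward, done = env.step(policy(obs))
            total += reward
        returns.append(total)
    return sum(returns) / len(returns)

print(benchmark(StubEnv(), policy))  # 10.0 for this stub
```

Is there an idiomatic jitted/scanned equivalent of this pattern in the repo, or should I just adapt the notebook's rollout?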