google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Evaluation only returns the cumulative reward of half of the episode #516

Closed fvgt closed 2 months ago

fvgt commented 3 months ago

I was using the SAC train.py function that is available in brax, and took a look at the full return of the unrolled scan used for evaluation, i.e. I removed the [0] index at the end of this function:

    def generate_eval_unroll(policy_params: PolicyParams,
                             key: PRNGKey) -> State:
      reset_keys = jax.random.split(key, num_eval_envs)
      eval_first_state = eval_env.reset(reset_keys)
      return generate_unroll(
          eval_env,
          eval_first_state,
          eval_policy_fn(policy_params),
          key,
          unroll_length=episode_length // action_repeat)[0]
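
Concretely, the only change is dropping the [0], so the caller gets both the final eval state and the unrolled transition data back. A sketch (the unpacking line is mine, not from train.py):

    def generate_eval_unroll(policy_params: PolicyParams,
                             key: PRNGKey):
      reset_keys = jax.random.split(key, num_eval_envs)
      eval_first_state = eval_env.reset(reset_keys)
      # No [0] at the end: return the full (final_state, data) tuple.
      return generate_unroll(
          eval_env,
          eval_first_state,
          eval_policy_fn(policy_params),
          key,
          unroll_length=episode_length // action_repeat)

    eval_state, data = generate_eval_unroll(policy_params, key)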

Now the evaluation returns both the eval state (which includes the eval metrics used for logging, i.e. eval_state.info['eval_metrics']) and the data of the full scan. For example, I can look at the discounts over the full episode:

    data.discount.shape
    (1000, 10)

My settings were an episode length of 1000 with an action repeat of 1, so the rollout length is 1000 (the first dimension of data), and I used 10 envs (the second dimension). Then I took a look at the discounts:

    data.discount
    Array([[1., 1., 1., ..., 1., 1., 1.],
           [1., 1., 1., ..., 1., 1., 1.],
           [1., 1., 1., ..., 1., 1., 1.],
           ...,
           [1., 1., 1., ..., 1., 1., 1.],
           [1., 1., 1., ..., 1., 1., 1.],
           [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

At first glance that made sense: the episode terminated after 1000 steps (I was using half cheetah). However, it also terminated halfway through:

    np.where(data.discount == 0.0)
    (array([499, 499, 499, 499, 499, 499, 499, 499, 499, 499, 999, 999, 999,
            999, 999, 999, 999, 999, 999, 999]),
     array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))

This is an issue, because the summed rewards also do not make sense. For example, if I want to compute the full (undiscounted) return of each episode, I would use

    sum(data.reward)
    Array([-207.69978, -233.7913 , -214.13663, -199.6694 , -301.35718,
           -304.1123 , -393.22888, -160.33006, -276.96848, -348.44727],
          dtype=float32)

Compare this with the cumulative reward that is already computed in the eval metrics:

    eval_metrics.episode_metrics['reward']
    Array([-120.13453 , -108.668724,  -69.54802 , -136.98099 , -211.82776 ,
           -100.90355 , -144.17926 ,  -73.8679  , -148.59505 , -150.05788 ],
          dtype=float32)
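
(For reference, eval_metrics here is just what comes back on the final eval state mentioned above; a small sketch:)

    eval_metrics = eval_state.info['eval_metrics']
    eval_metrics.episode_metrics['reward']   # shape (num_eval_envs,), here (10,)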

They are different, but the eval-metric values can be recovered by summing the reward only up to the 500th time step:

    sum(data.reward[:500])
    Array([-120.13453 , -108.668724,  -69.54802 , -136.98099 , -211.82776 ,
           -100.90355 , -144.17926 ,  -73.8679  , -148.59505 , -150.05788 ],
          dtype=float32)
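
More generally, instead of hard-coding the 500, the same numbers can be obtained by masking out everything after the first step where the discount hits zero. A sketch, assuming every env terminates at least once in the unroll:

    import jax.numpy as jnp

    done = data.discount == 0.0                    # (unroll_length, num_envs)
    first_done = jnp.argmax(done, axis=0)          # index of the first termination per env
    step_idx = jnp.arange(done.shape[0])[:, None]  # (unroll_length, 1)
    mask = step_idx <= first_done[None, :]         # keep steps up to and including the terminal one
    episode_return = jnp.sum(data.reward * mask, axis=0)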

So I am not sure whether this is a bug or intended behavior? If it is intended, what is the reason for it?

Edit:

This bug was on my side. The issue was that I was creating the env using

    envs.create(env_name)

instead of using

    envs.get_environment(env_name)

which works fine. I think the issue is that the first method already wraps the env, so calling the wrap function again, as is done in the SAC training pipeline, creates a double wrapping that leads to the behavior above.
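
For anyone else running into this, a minimal sketch of the pattern that works for me (parameter values are placeholders, and the train signature is from brax.training.agents.sac as far as I can tell):

    from brax import envs
    from brax.training.agents.sac import train as sac

    # get_environment returns the unwrapped env; the SAC training
    # pipeline applies its own episode/auto-reset wrappers internally.
    env = envs.get_environment('halfcheetah')

    make_inference_fn, params, metrics = sac.train(
        environment=env,
        num_timesteps=1_000_000,
        episode_length=1000,
        action_repeat=1,
        num_envs=10,
    )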

erikfrey commented 2 months ago

Glad you got it sorted - I agree that the functions envs.get_environment vs envs.create are not the most descriptive in terms of telling you what they're actually doing.

Just between you and me (and the rest of the internet), this whole envs.register, envs.create business is a bit of an overwrought abstraction. I've found it simpler to just import the env and instantiate it and wrap it myself.
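
A sketch of what that pattern looks like, with the caveat that the module paths below are from a recent brax checkout and have moved around between versions:

    from brax.envs import half_cheetah
    from brax.envs.wrappers import training as training_wrappers

    # Instantiate the env class directly, then apply the training wrappers
    # (episode bookkeeping, vmap, auto-reset) yourself instead of going
    # through the registry.
    raw_env = half_cheetah.Halfcheetah()
    env = training_wrappers.wrap(raw_env, episode_length=1000, action_repeat=1)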