keraJLi / rejax

Apache License 2.0

Large slowdown when evaluating more often (decreasing 'eval_freq') #13

Closed: e-zorzi closed this issue 2 months ago

e-zorzi commented 2 months ago

Hi, thanks for the repo! I'm trying it out, and I love that I can run experiments on complex environments at least 10x faster than with other non-JAX libraries! I also love that it's a bit more modular than other JAX libraries!

Today I ran quick experiments on different environments such as CartPole and HalfCheetah, and I noticed that when I evaluate more often (i.e., decrease eval_freq), training becomes a lot slower. For example, I wanted to evaluate more often than every 32,768 steps when training PPO on HalfCheetah, so I set eval_freq to roughly a tenth of that (around 3,000 steps), but it was much, much slower! I ran some tests on HalfCheetah and CartPole and found the following:

HalfCheetah:

| Total steps | eval_freq | Time (s) |
|---|---|---|
| 50_000 | 25_000 | 134 |
| 50_000 | 2_500 | 404 |

CartPole:

| Total steps | eval_freq | Time (s) |
|---|---|---|
| 50_000 | 500 | 12 |
| 50_000 | 50 | 22 |

I understand that evaluating more frequently may add some overhead, so I tried changing the number of seeds used for evaluation (from 200 to 20), as well as the number of environments and steps per environment, but nothing changed.

Is there a way to mitigate this overhead? Is it due to how jax.jit works? I tried jitting the function returned by 'make_eval', but nothing changed (I'm not a JAX expert, so this was probably a silly idea). Thanks for the great repo!

keraJLi commented 2 months ago

Hi @e-zorzi! I'm glad that you're enjoying Rejax ❤️

Since I don't know your exact setup, I can't pinpoint the cause with certainty. However, I strongly believe that the slowdown you're experiencing comes from the increased number of steps that must be simulated for evaluation.

Why does this happen?

Let's look at HalfCheetah as an example and assume that you're evaluating on a single seed only. You're training for 50k time steps and evaluating for 50_000 / eval_freq * episode_length time steps. With HalfCheetah's episode length of 1,000 steps, an eval_freq of 25k / 2.5k comes out to 2k / 20k evaluation steps. Your HalfCheetah times seem to imply a stronger slowdown; this might be because your GPU needs more time to load from memory for 200 environments.
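The arithmetic above fits in a tiny helper. This is a sketch under two assumptions: exactly total_timesteps // eval_freq evaluations happen, and every evaluation episode runs to its full length (1,000 steps for HalfCheetah). Evaluation seeds are not counted because they run in parallel rather than one after another. `sequential_eval_steps` is a hypothetical name, not part of Rejax.

```python
def sequential_eval_steps(total_timesteps: int, eval_freq: int, episode_length: int) -> int:
    """Number of evaluation steps simulated one after another during training.

    Hypothetical helper: assumes one full-length evaluation every `eval_freq`
    training steps. Evaluation seeds run in parallel, so they are not counted.
    """
    num_evals = total_timesteps // eval_freq
    return num_evals * episode_length


# HalfCheetah episodes last 1,000 steps.
for eval_freq in (25_000, 2_500):
    steps = sequential_eval_steps(50_000, eval_freq, episode_length=1_000)
    print(f"eval_freq={eval_freq:>6}: {steps} sequential eval steps")
```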

To confirm this, I've run some experiments on CartPole using the configuration in the configs folder:

| eval_freq | max_steps_in_episode (eval) | num_seeds (eval) | Total eval steps | Time (s) |
|---|---|---|---|---|
| 500 | 500 | 200 | 100,000 | 70 |
| 500 | 50 | 200 | 10,000 | 34 |
| 500 | 0 | 200 | 0 | 30 |
| 5000 | 500 | 200 | 10,000 | 35 |
| 5000 | 50 | 200 | 1,000 | 3.5 |

As you can see, training alone takes about 30 seconds, and evaluation adds roughly 5 s per 10k evaluation steps on top of that. This is fairly consistent when I plug in the other values.
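This rule of thumb amounts to a simple linear model: a fixed training cost plus a cost proportional to the number of sequential evaluation steps. The constants below are read off the CartPole table (30 s with no evaluation, about 5 s per 10k evaluation steps). They are specific to this run and hardware, and the 100,000-step row actually comes in under the estimate, so treat the slope as rough.

```python
BASE_TRAIN_SECONDS = 30.0    # CartPole run with no evaluation steps (table above)
SECONDS_PER_10K_EVAL = 5.0   # rough slope quoted above; hardware-specific


def estimated_runtime(total_eval_steps: int) -> float:
    """Rough total runtime in seconds: training cost plus linear evaluation cost."""
    return BASE_TRAIN_SECONDS + SECONDS_PER_10K_EVAL * total_eval_steps / 10_000


for steps in (0, 1_000, 10_000, 100_000):
    print(f"{steps:>7} eval steps -> ~{estimated_runtime(steps):.1f} s")
```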

Another possible cause is that some unusual hyperparameters make you actually train for more than total_timesteps steps. You can easily rule this out by inspecting the train_state (e.g., its global_step) after training.

How can you avoid it?

This is the hard part. Currently, the evaluation logic is set up in this way because it's convenient, and it's slow because it's sequential (we cannot parallelize across time). We can address both of these points, with some caveats.
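To illustrate why evaluation is sequential, here is a toy sketch (not Rejax's actual evaluation code). Each rollout is a `jax.lax.scan` over time, because every step depends on the previous state, while the evaluation seeds are independent and can be batched with `jax.vmap`. The environment and policy inside the hypothetical `rollout` function are made-up stand-ins. Under this structure, cutting the number of seeds only narrows the batch while the number of sequential steps stays the same, which may be why reducing the evaluation seeds from 200 to 20 made little difference.

```python
from functools import partial

import jax
import jax.numpy as jnp


def rollout(rng, num_steps):
    """Toy evaluation episode: returns the summed reward of one seed."""

    def step(carry, _):
        state, rng, ret = carry
        rng, key = jax.random.split(rng)
        action = -0.5 * state                                   # stand-in policy
        state = state + action + 0.1 * jax.random.normal(key)   # stand-in dynamics
        return (state, rng, ret - jnp.abs(state)), None

    init = (jnp.float32(0.0), rng, jnp.float32(0.0))
    # Each step needs the previous state, so the time loop cannot be parallelized.
    (_, _, ret), _ = jax.lax.scan(step, init, None, length=num_steps)
    return ret


def evaluate(rngs, num_steps):
    # Seeds are independent, so they are batched side by side with vmap.
    return jax.vmap(partial(rollout, num_steps=num_steps))(rngs)


evaluate_jit = jax.jit(evaluate, static_argnums=1)
rngs = jax.random.split(jax.random.PRNGKey(0), 200)
returns = evaluate_jit(rngs, 500)
print(returns.shape)
```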

Method 1: Evaluation during training

Other frameworks, such as cleanRL, log the returns of the episodes collected during training. This means no additional environment simulation! However, this is non-trivial to set up with Gymnax. We would like to implement an environment wrapper that tracks the current episode's return and stores the final return once the episode has finished. However, Gymnax implements an auto-reset that does not carry over state (which is where we would store the return), which makes this hard. There are two ways to address this:

  1. Overwrite gymnax.environments.environment.step to carry over state from the terminated environment. This is probably the cleanest option currently.
  2. Track current return within the train_state on an algorithmic level. This will become much easier in v0.1.0 (releasing soon), where this could be implemented as a mixin.
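For option 2, the bookkeeping itself is small. The sketch below assumes the step function returns a per-environment `done` flag for the transition that ended an episode (as an auto-resetting environment does), and keeps the running return in a state that would live in the train_state. `EpisodeStats`, `init_stats`, and `update_stats` are hypothetical names, not Rejax or Gymnax API.

```python
from typing import NamedTuple

import jax.numpy as jnp


class EpisodeStats(NamedTuple):
    current_return: jnp.ndarray  # return accumulated so far in the ongoing episode
    last_return: jnp.ndarray     # return of the most recently finished episode


def init_stats(num_envs: int) -> EpisodeStats:
    zeros = jnp.zeros(num_envs)
    return EpisodeStats(current_return=zeros, last_return=zeros)


def update_stats(stats: EpisodeStats, reward: jnp.ndarray, done: jnp.ndarray) -> EpisodeStats:
    """Call once per training step with per-env rewards and done flags."""
    episode_return = stats.current_return + reward
    # When an episode ends, store its return and restart the accumulator,
    # mirroring the auto-reset that the environment performs.
    last_return = jnp.where(done, episode_return, stats.last_return)
    current_return = jnp.where(done, 0.0, episode_return)
    return EpisodeStats(current_return=current_return, last_return=last_return)


# Two environments, three steps; env 1 finishes at step 2, env 0 at step 3.
stats = init_stats(2)
for reward, done in [([1.0, 1.0], [False, False]),
                     ([1.0, 2.0], [False, True]),
                     ([1.0, 1.0], [True, False])]:
    stats = update_stats(stats, jnp.array(reward), jnp.array(done))
print(stats.last_return, stats.current_return)
```

Because `update_stats` is a pure function of arrays, it can run inside a jitted training loop without any change to the environment's auto-reset.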

Method 2: Parallel evaluation after training

Instead of the current return, your eval_callback could return a snapshot of the model parameters, and you can then vmap evaluation across these parameters after training. This might be a very bad idea if you have larger networks or evaluate often, since the stored snapshots quickly fill up memory.
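A minimal sketch of Method 2, assuming the callback has returned one parameter pytree per evaluation point: stack the snapshots along a new leading axis with `jax.tree_util.tree_map`, then evaluate all of them at once with `jax.vmap`. The linear policy, the toy dynamics, and `policy_return` are hypothetical stand-ins. The stacked pytree holds every snapshot at once, which is exactly where the memory cost comes from.

```python
import jax
import jax.numpy as jnp


def policy_return(params, rng, num_steps=100):
    """Toy rollout of a hypothetical linear policy; returns the episode return."""

    def step(carry, _):
        state, rng, ret = carry
        rng, key = jax.random.split(rng)
        action = params["w"] * state + params["b"]
        state = state + action + 0.1 * jax.random.normal(key)
        return (state, rng, ret - jnp.abs(state)), None

    init = (jnp.float32(1.0), rng, jnp.float32(0.0))
    (_, _, ret), _ = jax.lax.scan(step, init, None, length=num_steps)
    return ret


# Pretend these snapshots were returned by the eval callback during training.
snapshots = [{"w": jnp.float32(-k / 10), "b": jnp.float32(0.0)} for k in range(11)]

# Stack into one pytree whose leaves have a leading "snapshot" axis.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *snapshots)

# Evaluate every snapshot in parallel, all with the same evaluation key.
returns = jax.jit(jax.vmap(policy_return, in_axes=(0, None)))(stacked, jax.random.PRNGKey(0))
print(returns.shape)
```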

Final words

So, what should you do? This depends on your requirements. Why do you need such a high resolution of your training curve? How much GPU memory do you have? How much effort do you want to put into making it work? Unfortunately, there is no easy solution.

Quick final note: if you jit the train function, your eval_callback is jitted along with it. You should actually not jit it individually, since that gives the compiler slightly less room to optimize.
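A toy sketch of that setup: the hypothetical `eval_callback` is an ordinary Python function called inside the training loop, and only the outer `train` function is wrapped in `jax.jit`, so the callback is traced and compiled as part of the same program. The quadratic "training" objective is made up for illustration and is not how Rejax trains.

```python
import jax
import jax.numpy as jnp


def eval_callback(params, rng):
    """Hypothetical evaluation: a noisy score of the current parameters."""
    return -jnp.sum(params**2) + 0.01 * jax.random.normal(rng)


def train(rng, num_iters=10):
    def body(carry, _):
        params, rng = carry
        rng, eval_rng = jax.random.split(rng)
        params = params - 0.1 * (2 * params)  # gradient step on sum(params**2)
        # Not jitted on its own: it is traced as part of `train`.
        return (params, rng), eval_callback(params, eval_rng)

    (params, _), evals = jax.lax.scan(body, (jnp.ones(3), rng), None, length=num_iters)
    return params, evals


train_jit = jax.jit(train)  # one jit around the whole training loop
params, evals = train_jit(jax.random.PRNGKey(0))
print(params, evals.shape)
```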

Let me know if you have any more questions! 🤗

e-zorzi commented 2 months ago

Thanks for the answer, very thorough! Evaluation is a very important topic in RL, but I understand that every library does it differently. I don't have a particular reason for collecting data more frequently; I was just surprised by the time difference! I will reduce the maximum episode length if I need to. The way evaluation is carried out right now seems reasonable. My only question is whether the evaluation environments are the same across different training seeds. For example, I used the standard 200 environments for PPO evaluation while training over 5 different seeds. Are the 200 environments used to evaluate the run with seed 1 different from those used to evaluate the run with seed 2? Looking at the code, it seems so to me. Wouldn't it be better, for consistency's sake, to always use the same seeds for the evaluation environments, independently of the training run's rng? Then we would always have the same 200 environments, and only the policies would differ across training seeds. To be honest, I don't think it would change much: with 200 environments we already reduce variance a lot, but the difference would be noticeable with fewer evaluation environments.

Also, have you considered writing a small wrapper or eval callback that collects data in the right shape for rliable? It's a nice way to do evaluations properly, and reviewers do like these kinds of plots!

P.S. Yes, the larger slowdown might be due to an incompatibility between my CUDA and JAX installations; I always get warnings that parallel compilation is disabled. I also use five seeds instead of one, though I don't think that changes anything.

keraJLi commented 2 months ago

Using the same evaluation seed in each vmapped run can easily be implemented by overriding the eval callback, e.g.:

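```python
import jax

def eval_callback(algo, ts, rng):
    # Ignore the per-run rng so every vmapped training run gets the same eval envs
    rng = jax.random.PRNGKey(ts.global_step)  # or some hardcoded value
    ...
```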
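To check that this makes the evaluation environments identical across training seeds, the sketch below vmaps a stripped-down key derivation over five training rngs. `eval_reset_keys` and `NUM_EVAL_ENVS` are hypothetical; the point is only that keys derived from `jax.random.PRNGKey(global_step)` do not depend on the training rng. Keyed on the global step, the environments are shared across runs but change from one evaluation to the next; a hardcoded value would keep them fixed over time as well.

```python
import jax
import jax.numpy as jnp

NUM_EVAL_ENVS = 200  # hypothetical number of evaluation environments


def eval_reset_keys(train_rng, global_step):
    """Keys for resetting the evaluation environments (hypothetical helper)."""
    del train_rng  # deliberately unused: eval envs must not depend on the training seed
    rng = jax.random.PRNGKey(global_step)
    return jax.random.split(rng, NUM_EVAL_ENVS)


train_rngs = jax.random.split(jax.random.PRNGKey(42), 5)  # five training seeds
keys = jax.vmap(eval_reset_keys, in_axes=(0, None))(train_rngs, jnp.asarray(32_768))
print(keys.shape, bool(jnp.all(keys == keys[0])))
```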

Keep in mind that reducing the evaluation length essentially changes the task. For example, in CartPole it's much easier to balance for 200 steps than for 500, because as you approach 500 steps you're much more likely to terminate by reaching the boundaries. This is why the default episode length was increased from 200 steps (CartPole-v0) to 500 steps in CartPole-v1.

I'm not sure if I understand your suggestion related to rliable. The input format for rliable is num_runs x num_games x frames (where frames refers to total_timesteps / eval_freq). Since we cannot vmap across different environments, this would leave us with num_runs x frames, which is already the shape of the evaluation that is returned. Could you elaborate on the format you were thinking of?
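For reference, stacking per-environment results into that layout only takes an `np.stack` after training, since each environment is trained in a separate call. The arrays below are random placeholders shaped num_runs x frames, the environment names are illustrative, and the sketch assumes all environments use the same number of evaluation points. For rliable's aggregate metrics, one would typically keep a single frame (e.g. the last), giving num_runs x num_games.

```python
import numpy as np

num_runs, num_frames = 5, 10
rng = np.random.default_rng(0)

# Placeholder results: one (num_runs, num_frames) array per environment,
# e.g. the mean evaluation return per training seed and evaluation point.
per_env = {
    "CartPole-v1": rng.normal(size=(num_runs, num_frames)),
    "halfcheetah": rng.normal(size=(num_runs, num_frames)),
}

# Stack along a new "game" axis: (num_runs, num_games, num_frames).
scores = np.stack([per_env[name] for name in sorted(per_env)], axis=1)

# Keep only the final evaluation point: (num_runs, num_games).
final_scores = scores[:, :, -1]
print(scores.shape, final_scores.shape)
```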

keraJLi commented 2 months ago

I hope I could help! I'll close the issue for now, feel free to reopen it if anything was unclear :-)