bmazoure / ppo_jax

JAX implementation of Proximal Policy Optimization (PPO), specifically tuned for Procgen, with benchmarked results and saved model weights for all environments.

PPO's performance #2

Open vwxyzjn opened 2 years ago

vwxyzjn commented 2 years ago

Hi @bmazoure,

Your PPO + JAX implementation caught my eye, and this is a really cool repo!

Based on your benchmark with W&B, I compared the performance of your implementation with mine and openai/baselines in this report. Here are some performance differences:

[image: per-environment performance comparison plots]

I suspect the reasons for the difference might be:

  1. Missing value bootstrapping in GAE: ppo_jax does not seem to bootstrap the value when the environment has not terminated (buffer.py#L18-L26), whereas the original implementation does (ppo2/runner.py#L56-L65, ppo2/runner.py#L50). A sketch of the bootstrapped computation follows this list.
  2. Slightly off layer initialization: ppo_jax uses the same initialization scale for both the policy and value heads (models.py#L91). However, the value head should be initialized with scale 1 instead of 0.01 (common/policies.py#L49-L63). See the second sketch below.
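
For (1), here is a minimal sketch of GAE with value bootstrapping in the style of baselines' ppo2/runner.py. The function and variable names are hypothetical (not taken from buffer.py), and the gamma/lam defaults are just typical Procgen settings:

```python
import numpy as np

def gae_with_bootstrap(rewards, values, dones, last_value, last_done,
                       gamma=0.999, lam=0.95):
    """rewards/values/dones have shape (n_steps,); last_value is the critic's
    estimate for the state reached after the rollout, last_done its done flag."""
    n_steps = len(rewards)
    advantages = np.zeros(n_steps)
    last_gae = 0.0
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            next_non_terminal = 1.0 - last_done
            next_value = last_value  # bootstrap from the post-rollout state
        else:
            next_non_terminal = 1.0 - dones[t + 1]
            next_value = values[t + 1]
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * lam * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns
```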
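
For (2), a sketch in Flax of the split head initialization from common/policies.py: orthogonal scale 0.01 on the policy head and 1.0 on the value head. The trunk here is a single hypothetical Dense layer standing in for ppo_jax's actual network:

```python
import jax.numpy as jnp
import flax.linen as nn
from jax.nn.initializers import orthogonal, constant

class ActorCriticHeads(nn.Module):
    n_actions: int

    @nn.compact
    def __call__(self, x):
        # hypothetical shared trunk (ppo_jax uses a conv encoder instead)
        h = nn.relu(nn.Dense(256, kernel_init=orthogonal(jnp.sqrt(2)),
                             bias_init=constant(0.0))(x))
        # policy head: small scale keeps the initial policy near-uniform
        logits = nn.Dense(self.n_actions, kernel_init=orthogonal(0.01),
                          bias_init=constant(0.0))(h)
        # value head: scale 1.0, as in openai/baselines
        value = nn.Dense(1, kernel_init=orthogonal(1.0),
                         bias_init=constant(0.0))(h)
        return logits, jnp.squeeze(value, axis=-1)
```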

What do you think?

bmazoure commented 2 years ago

Hey, thanks for taking the time to analyze this! Yeah, some environments definitely have inconsistencies. You're right, I did notice a lot of impact from the different inits, so you could try mimicking the TF1 init from stable_baselines. Regarding the bootstrapped value: in the snippet you linked, I made sure the for loop iterates over n_steps+1 (see https://github.com/bmazoure/ppo_jax/blob/main/train_ppo.py#L101), so technically the bootstrap value should be included; a sketch of that appended-value variant is below.
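
To illustrate that point, here is a hedged sketch (not the actual buffer.py code) of how storing n_steps + 1 values makes the bootstrap implicit, equivalent to the explicit version sketched above:

```python
import numpy as np

def gae_over_appended_values(rewards, values_ext, dones_ext, gamma=0.999, lam=0.95):
    """values_ext and dones_ext have length n_steps + 1; the last entries come from
    the state reached after the rollout, so indexing t + 1 inside the loop already
    includes the bootstrap value without a special case for the final step."""
    n_steps = len(rewards)
    advantages = np.zeros(n_steps)
    last_gae = 0.0
    for t in reversed(range(n_steps)):
        next_non_terminal = 1.0 - dones_ext[t + 1]
        delta = rewards[t] + gamma * values_ext[t + 1] * next_non_terminal - values_ext[t]
        last_gae = delta + gamma * lam * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values_ext[:n_steps]
    return advantages, returns
```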

If you want, you can open a PR with some changes, and I'll try to re-run the agent with the value function initialized at scale 1 and update the weights.