luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0

linear schedule is not linear (?) #6

Closed by Howuhh 9 months ago

Howuhh commented 9 months ago

Hi! I noticed that the linear learning-rate schedule is actually not linear when the number of steps is small. I doubt it makes much difference to the final results, but thought I'd show it anyway. Maybe I made a mistake somewhere?

import matplotlib.pyplot as plt

# Hyperparameters for a deliberately short run
total_timesteps = 10_000
num_minibatches = 32
num_steps = 128
num_envs = 4
num_epochs = 10
lr = 0.1
num_updates = total_timesteps // num_steps // num_envs  # number of PPO updates

def linear_schedule_purejaxrl(count):
    # Floor division maps the gradient-step count to a PPO update index,
    # so the learning rate only changes once per PPO update.
    frac = 1.0 - (count // (num_minibatches * num_epochs)) / num_updates
    return lr * frac

def linear_schedule(count):
    # Decays on every gradient step.
    frac = 1.0 - count / (num_epochs * num_minibatches * num_updates)
    return lr * frac

total_gradient_updates = (num_minibatches * num_epochs) * num_updates

plt.plot([linear_schedule(i) for i in range(1, total_gradient_updates + 1)], label="linear")
plt.plot([linear_schedule_purejaxrl(i) for i in range(1, total_gradient_updates + 1)], label="purejaxrl")
plt.legend()
plt.xlabel("Update")
plt.ylabel("Learning Rate")
plt.show()

Result: [plot of "Learning Rate" against "Update" for the "linear" and "purejaxrl" schedules; image not included]
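
For a plot-free check of the same comparison, here is a minimal sketch that uses the constants from the snippet above. It assumes the step counter starts at 0, and the names `steps_per_update`, `distinct_stepped`, `distinct_smooth`, and `max_gap` are made up for illustration. It counts how many distinct learning rates each schedule produces and measures the largest gap between them.

```python
# Same constants as in the snippet above.
total_timesteps = 10_000
num_minibatches = 32
num_steps = 128
num_envs = 4
num_epochs = 10
lr = 0.1
num_updates = total_timesteps // num_steps // num_envs

# Gradient steps taken inside one PPO update (illustrative name).
steps_per_update = num_minibatches * num_epochs
total_gradient_updates = steps_per_update * num_updates


def linear_schedule_purejaxrl(count):
    # Changes only when count crosses a multiple of steps_per_update.
    return lr * (1.0 - (count // steps_per_update) / num_updates)


def linear_schedule(count):
    # Changes on every gradient step.
    return lr * (1.0 - count / total_gradient_updates)


# Assumption: the counter runs from 0 to total_gradient_updates - 1.
counts = range(total_gradient_updates)
stepped = [linear_schedule_purejaxrl(c) for c in counts]
smooth = [linear_schedule(c) for c in counts]

distinct_stepped = len(set(stepped))  # one value per PPO update
distinct_smooth = len(set(smooth))    # one value per gradient step
max_gap = max(s - l for s, l in zip(stepped, smooth))

print("distinct learning rates (purejaxrl):", distinct_stepped)
print("distinct learning rates (linear):   ", distinct_smooth)
print("largest gap between schedules:      ", max_gap)
```

The gap is largest on the last gradient step before each PPO update boundary, where the stepped schedule has not yet decayed but the per-step schedule has.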

luchris429 commented 9 months ago

Ahh thank you very much! I'll swap yours in, or it would be great if you could make a PR.

luchris429 commented 9 months ago

Hmm, actually I think it depends on what we mean by "update".

I believe mine is based on CleanRL's implementation, which updates the learning rate every "PPO Update", not every "Gradient Update".
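
To make the two meanings of "update" concrete, here is a small sketch. It assumes each PPO update consists of `num_epochs * num_minibatches` gradient steps, as in the snippet earlier in the thread. The function names `lr_for_gradient_step` and `lr_for_ppo_update` are hypothetical and not taken from either library. Indexed by PPO update the schedule is exactly linear, and the floor division only converts a gradient-step counter into that index.

```python
num_minibatches = 32
num_epochs = 10
num_updates = 19  # total_timesteps // num_steps // num_envs from the thread
lr = 0.1

steps_per_update = num_minibatches * num_epochs


def lr_for_ppo_update(update):
    # Linear in the PPO update index: drops by lr / num_updates per update.
    return lr * (1.0 - update / num_updates)


def lr_for_gradient_step(count):
    # Same schedule, indexed by a gradient-step counter instead.
    ppo_update = count // steps_per_update
    return lr_for_ppo_update(ppo_update)


# Every gradient step within one PPO update sees the same learning rate.
for update in range(num_updates):
    first = update * steps_per_update
    last = first + steps_per_update - 1
    assert lr_for_gradient_step(first) == lr_for_ppo_update(update)
    assert lr_for_gradient_step(last) == lr_for_ppo_update(update)

# Consecutive PPO updates differ by a constant decrement.
decrements = [
    lr_for_ppo_update(u) - lr_for_ppo_update(u + 1) for u in range(num_updates - 1)
]
print("per-update decrement (min, max):", min(decrements), max(decrements))
```

Plotted against gradient steps, this looks like a staircase. Plotted against PPO updates, it is a straight line.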

Howuhh commented 9 months ago

Yeah, seems like it is indeed. Sorry for the confusion. Thanks for the implementation anyway!