LaurentMazare / tch-rs

Rust bindings for the C++ api of PyTorch.
Apache License 2.0

PPO example is actually A2C #797

Open Boxxfish opened 1 year ago

Boxxfish commented 1 year ago

I noticed while browsing the RL examples that the PPO implementation is actually A2C (for which there is already a separate example). On line 141, this line:

let action_loss = (-advantages.detach() * action_log_probs).mean(Kind::Float);

should look something like this:

let advantages = advantages.detach();
let term1 = (&action_log_probs - &old_action_log_probs).exp() * &advantages; // Importance sampling term
let term2 = (Tensor::ones(advantages.size(), (Kind::Float, device)) + epsilon * advantages.sign()) * &advantages; // Clipping term
let action_loss = -term1.minimum(&term2).mean(Kind::Float); // If the importance sampling ratio doesn't fall between 1 - ε and 1 + ε, clip it

We should also be recording the old log probabilities (old_action_log_probs) when collecting actions so we can compute the importance sampling ratio, roughly as in the sketch below.
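Something along these lines (untested; model, obs, and how the rollout buffer is filled are just placeholders, not the example's actual variables):

let (_critic, actor) = model(&obs);
let log_probs = actor.log_softmax(-1, Kind::Float);
let actions = log_probs.exp().multinomial(1, true); // sample one action per environment
let action_log_probs = log_probs.gather(-1, &actions, false).squeeze(); // log prob of the sampled actions
// Store `actions` and `action_log_probs.detach()` in the rollout buffer; the detached copy is what gets reused as old_action_log_probs.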

Also, and this is a nit, on line 151,

if let Err(err) = vs.save(format!("trpo{update_index}.ot")) {

should be changed to "ppo{update_index}.ot".

Boxxfish commented 1 year ago

On second thought, it would be more straightforward to write it like this:

let advantages = advantages.detach();
let ratio = (&action_log_probs - &old_action_log_probs).exp();
let term1 = &ratio * &advantages; // Importance sampling term
let term2 = ratio.clamp(1.0 - epsilon, 1.0 + epsilon) * &advantages; // Clipping term
let action_loss = -term1.minimum(&term2).mean(Kind::Float); // If the ratio doesn't fall between 1 - ε and 1 + ε, clip it
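For reference, this matches the clipped surrogate objective from the PPO paper: L^CLIP(θ) = Ê_t[min(r_t(θ) Â_t, clip(r_t(θ), 1 − ε, 1 + ε) Â_t)], where r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t) and Â_t is the advantage estimate.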