thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

No minibatch for computation of logp_old in PPOPolicy #1164

Open jvasso opened 2 weeks ago

jvasso commented 2 weeks ago

I have noticed that in the implementation of PPOPolicy, the old log probabilities logp_old are computed in a single full-batch forward pass, without minibatching:

with torch.no_grad():
    batch.logp_old = self(batch).dist.log_prob(batch.act)

This makes the algorithm unusable when the collected batch is too large to fit in memory in one forward pass, and there is no way to control this via batch_size. I suggest adding minibatch support:

logp_old = []
with torch.no_grad():
    # compute logp_old chunk by chunk so peak memory is bounded
    # by the minibatch size rather than the full rollout
    for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
        logp_old.append(self(minibatch).dist.log_prob(minibatch.act))
    batch.logp_old = torch.cat(logp_old, dim=0).flatten()
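
For illustration, the same chunked pattern works outside of Tianshou with plain torch. Below is a minimal standalone sketch with a toy Gaussian actor; the shapes, the toy network, and the chunk size of 256 are all made up for the example:

import torch
from torch.distributions import Independent, Normal

# Toy stand-ins: 100k transitions, 8-dim observations, 2-dim actions.
obs = torch.randn(100_000, 8)
act = torch.randn(100_000, 2)
actor = torch.nn.Linear(8, 2)   # toy mean network
sigma = torch.ones(2)           # fixed std for simplicity

chunks = []
with torch.no_grad():  # old log-probs need no gradient graph
    for obs_c, act_c in zip(obs.split(256), act.split(256)):
        dist = Independent(Normal(actor(obs_c), sigma), 1)
        chunks.append(dist.log_prob(act_c))  # shape: (chunk_size,)
logp_old = torch.cat(chunks)  # shape: (100_000,)
# Peak activation memory now scales with the chunk size,
# not with the full rollout length.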

I'm using Tianshou 1.0.0.
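
In the meantime, a possible workaround is to subclass PPOPolicy and override process_fn. This is a rough, untested sketch: it assumes that PPOPolicy's process_fn only adds the logp_old computation on top of what the A2C parent's process_fn produces (returns and advantages), which should be double-checked against the 1.0.0 source, and it skips any extra bookkeeping PPO may do when recompute_advantage=True. The chunk size of 256 and the class name are made up:

import torch
from tianshou.policy import PPOPolicy

class ChunkedLogpPPOPolicy(PPOPolicy):
    # Hypothetical subclass, not part of Tianshou.
    def process_fn(self, batch, buffer, indices):
        # Delegate returns/advantage computation to the A2C parent,
        # bypassing PPOPolicy's full-batch logp_old forward pass.
        batch = super(PPOPolicy, self).process_fn(batch, buffer, indices)
        logp_old = []
        with torch.no_grad():
            for minibatch in batch.split(256, shuffle=False, merge_last=True):
                logp_old.append(self(minibatch).dist.log_prob(minibatch.act))
        batch.logp_old = torch.cat(logp_old, dim=0).flatten()
        return batch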

MischaPanch commented 1 week ago

You're right, wanna make a PR for that? Otherwise I can also make one myself.