nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Log prob computation in TDMPC2 #44

Open wertyuilife2 opened 3 months ago

wertyuilife2 commented 3 months ago

I have some questions about the way logprob is calculated in TDMPC2.

The code below shows that the logprob result in TDMPC2 seems to differ from torchrl and SAC.

import math

import torch
import torch.nn.functional as F
from torchrl.modules.distributions import TanhNormal

@torch.jit.script
def _gaussian_residual(eps, log_std):
    return -0.5 * eps.pow(2) - log_std

@torch.jit.script
def _gaussian_logprob(residual):
    # subtract the Gaussian normalization constant 0.5 * log(2 * pi)
    return residual - 0.5 * math.log(2 * math.pi)

def gaussian_logprob(eps, log_std, size=None):
    """
    Original TD-MPC2 code: sum the residual over action dimensions first,
    subtract the normalization constant once, then multiply the log prob
    by the action dimension (size).
    """
    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
    if size is None:
        size = eps.size(-1)
    return _gaussian_logprob(residual) * size

def gaussian_logprob_sac(eps, log_std):
    """SAC-style Gaussian log prob: subtract the constant per dimension, then sum."""
    residual = _gaussian_residual(eps, log_std)
    return _gaussian_logprob(residual).sum(-1, keepdim=True)

@torch.jit.script
def _squash(pi):
    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)

def squash(mu, pi, log_pi):
    """Apply tanh squashing and subtract the log-det-Jacobian correction from log_pi."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi

if __name__ == "__main__":
    mu = torch.tensor([[-0.2,-0.1],[0.1,0.2]],dtype=torch.float)
    log_std = torch.tensor([[-2,-1],[0,1]],dtype=torch.float)
    dist = TanhNormal(loc=mu, scale=log_std.exp())  

    eps = torch.tensor([[0.1,0.2],[0.3,0.4]],dtype=torch.float)
    pi = mu + eps * log_std.exp()

    # compute logprob in sac
    log_pi = gaussian_logprob_sac(eps, log_std)
    _, _, log_pi = squash(mu, pi, log_pi)
    print("sac logprob:", log_pi)

    # compute logprob in tdmpc
    log_pi = gaussian_logprob(eps, log_std, size=None)
    mu, pi, log_pi = squash(mu, pi, log_pi)
    print("tdmpc logprob:", log_pi)

    # compute logprob using torchrl
    logprob = dist.log_prob(pi)
    print("torchrl logprob:", logprob)

code result:

sac logprob: tensor([[ 1.1724],
        [-1.4718]])
tdmpc logprob: tensor([[ 4.1474],
        [-2.5968]])
torchrl logprob: tensor([ 1.1724, -1.4718])

Specifically, TDMPC2 sums the residual over action dimensions first, subtracts the normalization constant once, and then multiplies the log prob by the action dimension. Could you please explain why this is done?
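
For reference, my reading of the two code paths above, writing $D$ for the action dimension, $r_i = -\tfrac{1}{2}\epsilon_i^2 - \log\sigma_i$ for the per-dimension residual, and $\pi_i$ for the pre-tanh action:

$$\text{SAC / torchrl:}\quad \log\pi(a) = \sum_{i=1}^{D} r_i - \frac{D}{2}\log 2\pi - \sum_{i=1}^{D}\log\big(1-\tanh^2(\pi_i)\big)$$

$$\text{TD-MPC2:}\quad \log\pi(a) = D\Big(\sum_{i=1}^{D} r_i - \frac{1}{2}\log 2\pi\Big) - \sum_{i=1}^{D}\log\big(1-\tanh^2(\pi_i)\big)$$

The constant term works out the same in both, but the data-dependent residual sum is scaled by $D$, while the tanh correction from squash() is not.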

nicklashansen commented 2 weeks ago

Thanks for pointing out this discrepancy! The original motivation for this implementation was to account for masked out action dimensions in the multi-task case, but I agree that it should ideally be addressed in a way that does not alter the log probabilities when no masking is applied. I'll try to issue a fix soon. Either way, I don't believe this discrepancy should affect results in any meaningful way.
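
For what it's worth, here's a minimal sketch of one way the masking could be handled without changing the unmasked case; the masked_gaussian_logprob name and action_mask argument are hypothetical, not the repo's actual API:

import math
import torch

def masked_gaussian_logprob(eps, log_std, action_mask=None):
    """Hypothetical sketch: per-dimension Gaussian log prob, summed only over
    valid action dimensions, so the result matches the SAC/torchrl convention
    whenever no mask is applied."""
    log_prob = -0.5 * eps.pow(2) - log_std - 0.5 * math.log(2 * math.pi)
    if action_mask is not None:
        log_prob = log_prob * action_mask  # zero out padded action dimensions
    return log_prob.sum(-1, keepdim=True)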

wertyuilife2 commented 1 week ago

Thanks for the explanation! I also agree that this isn’t a fundamental issue—I was just genuinely curious about the reasoning behind this design choice.

However, based on my previous experiments, at least on the Dog-Run task in DMControl, multiplying by the action dim does seem to bring a noticeable performance improvement compared to not doing so. I suggest paying attention to the performance on this task when applying the fix.

nicklashansen commented 1 week ago

Interesting! It's possible that the scaling increases policy entropy a bit for tasks with large action spaces? I wonder if scaling the entropy coef with action space dim would yield the same result.
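
A rough sketch of that idea (illustrative names only, not the actual TD-MPC2 policy update):

def policy_loss(log_pi, q_value, entropy_coef, action_dim):
    """Hypothetical SAC-style objective: scale the entropy coefficient by the
    action dimensionality instead of scaling log_pi inside gaussian_logprob."""
    return ((entropy_coef * action_dim) * log_pi - q_value).mean()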

wertyuilife2 commented 1 week ago

Scaling the entropy coef and scaling the log prob don't seem mathematically identical due to the additional squash() step, but I believe they will have similar effects.
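
To make that concrete (assuming the policy objective uses the log prob linearly with some coefficient $\alpha$): with $r_i = -\tfrac{1}{2}\epsilon_i^2 - \log\sigma_i$ and $s = \sum_i \log(1-\tanh^2(\pi_i))$, the current scaling optimizes an entropy term proportional to

$$\alpha\Big(D\big(\textstyle\sum_i r_i - \tfrac{1}{2}\log 2\pi\big) - s\Big),$$

while using the corrected log prob and multiplying $\alpha$ by $D$ gives

$$\alpha D\Big(\textstyle\sum_i r_i - \tfrac{D}{2}\log 2\pi - s\Big).$$

The two differ by $\alpha (D-1)\, s$ plus a constant, so the gradients aren't identical, but the overall effect should indeed be similar.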

ShaneFlandermeyer commented 4 days ago

I'm also curious about this. I wouldn't expect to have to scale the entropy/logprobs at all since they're already proportional to the action dim, but as @wertyuilife2 said I've also seen performance drops with the corrected equations.

ShaneFlandermeyer commented 2 days ago

And since @nicklashansen wanted to test this, here's the Dog-Run training reward for the original implementation (gray) and for the corrected equation with the entropy coefficient scaled directly (blue):

[Screenshot: Dog-Run training reward curves]

nicklashansen commented 2 days ago

Interesting, thanks for running this! I wonder if this is more of a one-off phenomenon for Dog or if it applies to high-dimensional action spaces in general. I would caution against drawing conclusions from a small set of experiments, since it's still RL after all. I'll see if I can run some experiments on this soon :-)