pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Incorrect Calculation of Mode for TanhNormal Distribution #2186

Closed Emile-Aquila closed 4 months ago

Emile-Aquila commented 4 months ago

Describe the bug

I'd like to express my gratitude for the swift and diligent maintenance of the torchrl library.

I have identified a potential issue with the implementation of the mode method in the TanhNormal distribution within TorchRL. The calculation of the mode appears to be incorrect due to the nature of the tanh function applied to a normal distribution.

I think the mode of the TanhNormal distribution should accurately reflect the peak of the probability density function after applying the tanh transformation to the underlying normal distribution. Given the non-linearity of the tanh function, the mode calculation should account for this complexity.


To Reproduce

The current implementation of the mode method does not correctly compute the mode, resulting in inaccurate values. For example, in the following scenario, the mode is expected to be around 1, but the method returns approximately 0.197.

import torch
from torchrl.modules import TanhNormal
import matplotlib.pyplot as plt

torch.random.manual_seed(0)

loc = torch.tensor([0.2], dtype=torch.float32)
scale = torch.tensor([1.0], dtype=torch.float32)

dist = TanhNormal(loc, scale, min=-1, max=1)
print("mode: ", dist.mode.item())  # mode:  0.1973753273487091

sample = dist.sample_n(10000)
plt.hist(sample.numpy(), bins=500, range=(-1, 1))
plt.show()

[Histogram of the 10,000 samples produced by the script above.]
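The reported discrepancy can be checked directly with a change of variables: if Y = tanh(X) with X ~ N(0.2, 1), then log p_Y(y) = log p_X(atanh(y)) - log(1 - y^2), and the extra -log(1 - y^2) term pushes density toward the boundaries, so the peak is not at tanh(loc). A minimal grid-search sketch (plain PyTorch, independent of TorchRL's API):

```python
import torch

# Y = tanh(X), X ~ Normal(0.2, 1). By the change-of-variables formula:
#   log p_Y(y) = log p_X(atanh(y)) - log(1 - y^2)
loc, scale = 0.2, 1.0
base = torch.distributions.Normal(loc, scale)

# Dense grid over the open interval (-1, 1); argmax approximates the mode.
y = torch.linspace(-0.999, 0.999, 200001)
log_p = base.log_prob(torch.atanh(y)) - torch.log1p(-y.pow(2))
mode = y[log_p.argmax()]
print(mode.item())  # ~0.97, far from tanh(0.2) ~= 0.197
```

For these parameters the density is actually bimodal (there is a smaller peak near -1), and the global peak sits near 0.97, consistent with the histogram.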



Thank you again for your continuous support and hard work on maintaining the torchrl library.

vmoens commented 4 months ago

Hello, thanks for reporting this!

#2198 should fix it and improve the API.

Note that properly finding the mode of the distribution requires finding its maximum (well, the mode *is* the maximum, haha), but there is no analytical expression. For the mean I implemented it using a regular stochastic expectation, but for the maximum I had to rely on Newton-Raphson, which is considerably slower than what we had before, and non-differentiable.
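For reference, the stationarity condition behind a Newton-Raphson approach can be written in the pre-tanh space: with y = tanh(x), maximizing log p_Y(tanh(x)) = -(x - mu)^2 / (2 sigma^2) + 2 log cosh(x) + const gives (x - mu)/sigma^2 = 2 tanh(x). A small sketch of that root-finding step (an illustration of the idea, not the PR's actual code):

```python
import math

# Solve f(x) = (x - mu)/sigma**2 - 2*tanh(x) = 0 by Newton-Raphson.
# The mode of Y = tanh(X) is then tanh(x*) at the density-maximizing root.
mu, sigma = 0.2, 1.0

def f(x):
    return (x - mu) / sigma**2 - 2.0 * math.tanh(x)

def fprime(x):
    return 1.0 / sigma**2 - 2.0 / math.cosh(x)**2

x = 3.0  # start beyond the largest root so the iteration converges to it
for _ in range(50):
    x -= f(x) / fprime(x)
mode = math.tanh(x)
print(mode)  # ~0.973 for mu=0.2, sigma=1
```

Since there can be several stationary points (the density can be bimodal), a robust implementation would compare the density at every root; the sketch above simply starts beyond the largest one.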

Basically, before it was fast, now it's accurate


vmoens commented 4 months ago

[Histogram of the samples produced by the script below.]

import torch
from torchrl.modules import TanhNormal
import matplotlib.pyplot as plt

torch.random.manual_seed(0)

loc = torch.tensor([0.2], dtype=torch.float32)
scale = torch.tensor([1.0], dtype=torch.float32)

dist = TanhNormal(loc, scale, min=-1, max=1)
print("mode: ", dist.mode.item())  # mode:  1.0

sample = dist.sample_n(100000)
plt.hist(sample.numpy(), bins=64, range=(-1, 1))
plt.show()

Now results in 1.0

Emile-Aquila commented 4 months ago

Hello

Thank you for your prompt response and for fixing the implementation error in the mode calculation of the TanhNormal distribution. While the calculation has become slower, I understand that this is a necessary trade-off.

I really appreciate it.

vmoens commented 4 months ago

Given the time it takes to compute the mode now, wouldn't it make sense to make it a method (not a property) and raise a deprecation warning in the current property? Intuitively I think most users expect a property to be "fast". I'm mostly worried about the runtime of common algos if we adopt this new property, whereas we could redirect users towards an alternative solution if they want to use this.

vmoens commented 4 months ago

I updated the PR to rely on Adam, which is faster and more accurate than LBFGS, SGD, and Newton-Raphson (crazy, right?)
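The Adam-based approach can be sketched as unconstrained gradient ascent on the log-density over the pre-tanh variable, so no clipping to (-1, 1) is needed. Again, this is only an illustration of the general idea with assumed starting point and hyperparameters, not the PR's implementation:

```python
import torch

# Gradient-ascent sketch: maximize log p_Y(tanh(x)) over the pre-tanh
# variable x using Adam, where
#   log p_Y(tanh(x)) = log p_X(x) - log(1 - tanh(x)^2)
loc = torch.tensor([0.2])
scale = torch.tensor([1.0])
base = torch.distributions.Normal(loc, scale)

x = loc.clone().requires_grad_(True)  # start from the base-distribution mean
opt = torch.optim.Adam([x], lr=0.05)
for _ in range(1000):
    opt.zero_grad()
    y = torch.tanh(x)
    loss = -(base.log_prob(x) - torch.log1p(-y.pow(2)))  # negative log-density
    loss.backward()
    opt.step()

mode = torch.tanh(x.detach())
print(mode.item())  # ~0.97 for loc=0.2, scale=1
```

Note that tanh strongly compresses errors near the boundaries, so even a loose optimum in x gives an accurate mode in y; a production version would still want to handle bimodal cases by comparing candidates.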

Emile-Aquila commented 4 months ago

Does this mean we should implement the calculation of the correct mode as a separate method and revert the mode property to its original implementation? Now that you mention it, that seems more practical to me. Sorry for my confusing issue.

vmoens commented 4 months ago

Yes, if you look at the current status of the PR, this is how I went about it.

vmoens commented 4 months ago

Fixed by #2198