pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] NoisyLinear's noise is constant #1963

Open Shmuma opened 4 months ago

Shmuma commented 4 months ago

Describe the bug

In paper "Noisy Networks for exploration" they say (section 3.1) "A noisy network agent samples a new set of parameters after every step of optimisation. "

But in the current implementation, the epsilon values are initialized only when the layer is created and remain constant afterwards. To get the paper's behaviour, the user has to call the reset_noise() method explicitly, which is not mentioned in the documentation.

I propose either mentioning this in the documentation or (better, from my perspective) resetting the noise after every forward pass of the layer.
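
To make the current workflow concrete, here is a minimal sketch of the explicit reset the user currently has to remember (the single-layer "loss" below is just a placeholder, not a real objective):

import torch
from torchrl.modules import NoisyLinear

layer = NoisyLinear(2, 3)
opt = torch.optim.Adam(layer.parameters(), lr=0.01)

for _ in range(3):
    loss = layer(torch.ones(2)).sum()  # stand-in for a real training loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    layer.reset_noise()  # without this explicit call the epsilon buffers never change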

To Reproduce

>>> import torch
>>> from torchrl.modules import NoisyLinear
>>> l = NoisyLinear(2, 3)
>>> l.weight_epsilon
tensor([[-0.0623,  0.0107],
        [ 0.7793, -0.1333],
        [-1.7952,  0.3072]])
>>> opt = torch.optim.Adam(l.parameters(), lr=0.01)
>>> t = l(torch.ones(2))
>>> t
tensor([ 0.3981, -0.6741, -0.3556], grad_fn=<ViewBackward0>)
>>> t.sum().backward()
>>> opt.step()
>>> l.weight_epsilon
tensor([[-0.0623,  0.0107],
        [ 0.7793, -0.1333],
        [-1.7952,  0.3072]])
>>> l.reset_noise()
>>> l.weight_epsilon
tensor([[ 0.7776, -0.7971],
        [ 0.7615, -0.7806],
        [-0.7546,  0.7735]])

Expected behavior

Noise parameters are resampled after every training step.

System info

Using the stable version, but the main branch is affected as well.


vmoens commented 4 months ago

Hey, thanks for this! You're absolutely right that the doc should be clearer. The reason we don't do it at every forward is that we want to leave the user in control, since this module also has to be used during inference (I suppose you don't want to reset the noise at every step during data collection).

I'm actually working on #1587, which faces the same problem, so we definitely need to think of a nice, integrated way of doing this that does not compromise flexibility but is easy to use.

I will keep you posted on this

Shmuma commented 4 months ago

Thanks for the reply!

What do you think about using a backward hook to resample the noisy parameters? From my naive perspective it looks like a proper way to execute code only when we're optimizing the layer.

I mean this method: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_pre_hook
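
Roughly what I have in mind (just a sketch, and the _needs_noise_reset flag is a made-up attribute, not something in torchrl): the backward pre-hook only marks the layer as being optimised, and the actual resampling happens lazily at the next forward, so we don't touch the epsilon buffers while autograd might still need the values it saved during the forward. register_full_backward_pre_hook needs PyTorch 2.0+.

from torchrl.modules import NoisyLinear

layer = NoisyLinear(2, 3)

def mark_for_resampling(module, grad_output):
    # runs only when a backward pass reaches this layer, i.e. when it is being optimised
    module._needs_noise_reset = True  # made-up flag, not a torchrl attribute

def maybe_reset_noise(module, args):
    # resample lazily at the next forward rather than inside the backward itself
    if getattr(module, "_needs_noise_reset", False):
        module.reset_noise()
        module._needs_noise_reset = False

layer.register_full_backward_pre_hook(mark_for_resampling)
layer.register_forward_pre_hook(maybe_reset_noise)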

vmoens commented 4 months ago

That could be the default behaviour, yes! I would leave an option for the user to deactivate it and take control over resetting. Also, am I right to think that the noise should not be synced with the policy used for inference? I think that if you use your policy with a noisy linear in a collector, the same noise will be used on both sides (because we sync both params and buffers).
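
A quick way to check my assumption that the noise travels with the sync (the epsilon tensors should be registered as buffers):

from torchrl.modules import NoisyLinear

layer = NoisyLinear(2, 3)
# if the epsilon tensors are buffers, they end up in the state_dict and get copied
# whenever we sync params and buffers to the inference copy of the policy
print([name for name, _ in layer.named_buffers()])  # I'd expect weight_epsilon / bias_epsilon here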

Shmuma commented 4 months ago

Also, am I right to think that the noise should not be synced with the policy used for inference?

Good question, need to reread the paper.

From my understanding, the agent should be robust to both options; in the end, we just learn how to deal with a noisy component in our weights. So it's up to the agent to decide how much and when it might use this randomness. From that perspective, resampling epsilon could potentially be useful even during inference.