pytorch/rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] Model buffers updated in the loss are not reflected outside the loss #1702

Closed: matteobettini closed this issue 11 months ago

matteobettini commented 11 months ago

Describe the bug

When a model is passed to a loss and one of its buffers is updated during the loss's forward pass, the update is not visible on the model that was passed in. The update does propagate if the buffer is modified in-place.

To Reproduce

import torch
from tensordict import TensorDict
from torch import nn
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 1)
        self.register_buffer("buffer_a", torch.tensor([0.0]))

    def forward(self, x):
        if torch.is_grad_enabled():
            # note: this rebinds the attribute to a new tensor (not an in-place update)
            self.buffer_a = torch.tensor([1.0])
        print(f"Buffer value: {self.buffer_a}")
        return self.linear(x)

if __name__ == "__main__":
    module = Model()
    actor = QValueActor(module=module, action_space="one_hot")
    loss = DQNLoss(value_network=actor, action_space="one_hot")
    td = TensorDict(
        {
            "observation": torch.zeros(5),
            "next": TensorDict(
                {
                    "observation": torch.zeros(5),
                    "reward": torch.zeros(1),
                    "done": torch.zeros(1, dtype=torch.bool),
                },
                [],
            ),
        },
        [],
    )

    print("Forward the model for collection")
    with torch.no_grad():
        actor.forward(td)
    print("Forward the model for training")
    loss.forward(td)
    print("Re-forward the model for collection")
    with torch.no_grad():
        actor.forward(td)
Output:

Forward the model for collection
Buffer value: tensor([0.])
Forward the model for training
Buffer value: tensor([1.])
Buffer value: tensor([1.])
Re-forward the model for collection
Buffer value: tensor([0.])
vmoens commented 11 months ago

This

        if torch.is_grad_enabled():
            self.buffer_a = torch.tensor([1.0])

overrides the buffer and turns it into a regular tensor. Have you tried modifying it in-place?
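
Something along these lines, i.e. writing into the existing buffer rather than rebinding the attribute (just a sketch of the repro's forward; copy_ and fill_ are standard in-place tensor ops):

    def forward(self, x):
        if torch.is_grad_enabled():
            # write into the registered buffer instead of rebinding the attribute,
            # so the tensor object the loss captured is the one that changes
            self.buffer_a.copy_(torch.tensor([1.0]))  # or self.buffer_a.fill_(1.0)
        print(f"Buffer value: {self.buffer_a}")
        return self.linear(x)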

vmoens commented 11 months ago

Ah, never mind, it seems the buffer is kept in this case; let me have a deeper look.

matteobettini commented 11 months ago

Yes, as per the description, in-place updates work. But, for example, self.buffer_a = self.buffer_a + 1 does not.
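
To make the distinction explicit (just the two user-code variants, nothing torchrl-specific):

# propagates: mutates the registered buffer in place, its identity is preserved
self.buffer_a.add_(1)              # equivalent to self.buffer_a += 1

# is lost outside the loss: builds a new tensor and rebinds the module attribute,
# so the tensor captured by the functional call is never touched
self.buffer_a = self.buffer_a + 1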

vmoens commented 11 months ago

Got it. The reason is that, currently, if you change the identity of a param or buffer during a functional call, the change isn't reflected in the tensordict that was passed. It will be fixed once we use the new functional API.
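
Roughly the same behaviour can be reproduced with tensordict's functional API alone, assuming the context-manager form of to_module available in recent tensordict versions (a standalone sketch, not the loss internals):

import torch
from tensordict import TensorDict
from torch import nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(1))

    def forward(self, x):
        self.buf = self.buf + 1  # rebinds the attribute to a brand-new tensor
        return x

m = M()
params = TensorDict.from_module(m)   # snapshot of the module's params and buffers
with params.to_module(m):            # functional-style call: swap in, restore on exit
    m(torch.zeros(1))

print(params["buf"], m.buf)          # both still tensor([0.]): the new tensor was dropped on exit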

vmoens commented 11 months ago

I'm doing the refactoring of the losses. One thing that makes this impossible currently is that we're using the same module for target and regular calls, just swapping the parameters and buffers using a functional call. Hence, we'll have multiple calls to the module during a single forward and it's unclear which one should affect the original module (one could expect that the module is only changed by the calls made with regular params, not target, but it doesn't seem as clear as that in general).

I think the answer will be that users should either make the changes in-place, or re-sync the module params with something like

loss_module(data)
loss_module.actor_params.to_module(policy_module)

which will update the policy_module attributes with the updated params and buffers. This solution will also support non-in-place modifications of buffers.
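
Applied to the repro above, it would look something like this (assuming the DQNLoss exposes its functional parameters and buffers under value_network_params, named after the network it wraps):

loss.forward(td)                            # training step: the buffer changes inside the loss
loss.value_network_params.to_module(actor)  # push the loss's params/buffers back onto the actor

with torch.no_grad():
    actor.forward(td)                       # the collection-side actor now sees the new buffer value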

matteobettini commented 11 months ago

One thing that makes this impossible currently is that we're using the same module for target and regular calls, just swapping the parameters and buffers using a functional call. Hence, we'll have multiple calls to the module during a single forward and it's unclear which one should affect the original module (one could expect that the module is only changed by the calls made with regular params, not target, but it doesn't seem as clear as that in general).

Why do we care which one affects the buffers? Wouldn't we just need to reflect any change in the buffers as soon as it happens (independently of the context of the call)? For example, in my script the user changes the buffer only when torch.is_grad_enabled() is True, thus making the change only in the non-target case.

loss_module.actor_params.to_module(policy_module): this confuses me a bit. Are the buffers considered params? Are we creating target buffers? In my opinion, buffers should not be included with the targets or the params, and should be shared at all times among all instances of the model (target, non-target, collection).

matteobettini commented 11 months ago

Also, would in-place modification work when collection and loss models are sitting on different devices?

vmoens commented 11 months ago

Why do we care which one affects the buffers?

I don't get that comment; can you elaborate?

Wouldn't we just need to reflect any change in the buffers as soon as it happens (independently of the context of the call)? For example, in my script the user changes the buffer only when torch.is_grad_enabled() is True, thus making the change only in the non-target case.

This is custom logic in your module that we can't capture in the loss module (we can't read through the module and detect the calls to is_grad_enabled()). Also, it isn't a given that computations with the target params have grad disabled (e.g. meta-RL and many other cases where you still build a graph with the target params). I'm dubious about using whether or not the current op is part of a graph as a way to detect inference vs. training: there are several cases where gradient computation is required during inference, and similarly many cases where gradients are disabled during training. I think this shouldn't be relied on in RL (maybe I misunderstood the example though).

That said, your params (and potentially the target params) will be affected, but only within the loss module.

For info, we need to do a deep copy of the stateless module before putting it in the loss module, because we store the parameters independently from the module itself. If we didn't do the deep copy, we would have duplicates of the params when calling loss_module.parameters(). Given this, any modification of the module itself will be lost, and I cannot think of a way to avoid that other than not using functional parameters, which is a no-go at the moment.

This confuses me a bit: are the buffers considered params? Are we creating target buffers? In my opinion buffers should not be included with the targets or params and should be shared at all times among all the instances of the model (target, non-target, collection)

That's controversial :) Take batch-norm: if you're using something like that, your buffers are synced with the params and should not be shared. I think this is the general case, i.e. buffers and params are synced.
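
A plain-PyTorch illustration of that coupling:

import torch
from torch import nn

bn = nn.BatchNorm1d(3)
print(bn.running_mean)     # buffer, starts at zeros
bn(torch.randn(16, 3))     # a training-mode forward updates running_mean/running_var in place
print(bn.running_mean)     # the buffer now reflects statistics tied to these particular params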

matteobettini commented 11 months ago

OK, so what I understand from this is that you are saying the only possibility for buffers is to be part of the parameters when the loss makes the module functional.

And given this, the only thing we can do is update them in-place or have some code to copy them over.

But if we update them in-place, we would still need code to copy them over (something like collector._update_policy_weights()) because the two models could be on different devices.
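
A small illustration of that last point (generic tensors only, no collector API):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

train_buf = torch.zeros(1)                  # buffer held by the training/loss copy
collect_buf = train_buf.clone().to(device)  # separate copy on the collection device

train_buf.fill_(1.0)                        # in-place update on the training side
print(collect_buf)                          # still tensor([0.]): in-place writes do not cross
                                            # device copies, so an explicit weight sync is needed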

matteobettini commented 11 months ago

That's controversial :) Take batch-norm: if you're using something like that, your buffers are synced with the params and should not be shared. I think this is the general case, i.e. buffers and params are synced.

So the batch-norm buffers are part of the target params and are updated with the same strategy (hard/soft)? Is this how it is normally done?

vmoens commented 11 months ago

Not sure what "normally" refers to, but yes, this is how it is done at the moment.

vmoens commented 11 months ago

This is now handled as well as we can with functional calls. Closing as "not planned", since I'm not sure how this would unfold when using target networks, meta-learning, etc.