Closed by matteobettini 11 months ago
This

```python
if torch.is_grad_enabled():
    self.buffer_a = torch.tensor([1.0])
```

overrides the buffer and makes it a regular tensor; have you tried modifying it in-place?
Ah, never mind, it seems the buffer is kept in this case; let me have a deeper look.
Yes, as per the description, in-place works. But, for example, `self.buffer_a = self.buffer_a + 1` does not.
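For reference, a minimal sketch of the two update styles being discussed, reusing the `buffer_a` name from the repro (the module class itself is illustrative):

```python
import torch
import torch.nn as nn

class BufferModule(nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer, as in the repro
        self.register_buffer("buffer_a", torch.zeros(1))

    def forward(self, x):
        # in-place update: the registered tensor is mutated, its identity is
        # preserved (this is the variant that works)
        self.buffer_a.add_(1.0)

        # re-binding (commented out): assigns a brand-new tensor to the
        # attribute, changing the buffer's identity; this is the variant
        # reported not to propagate
        # self.buffer_a = self.buffer_a + 1

        return x + self.buffer_a
```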
Got it. The reason is that, currently, if you change the identity of a param or buffer during a functional call, the change isn't reflected in the tensordict that was passed. It will be fixed once we use the new functional API.
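To illustrate the mechanism, a sketch only, assuming a recent tensordict version where `TensorDict.from_module` / `to_module` are available:

```python
import torch
from tensordict import TensorDict

module = BufferModule()                   # the module sketched above
params = TensorDict.from_module(module)   # params and buffers as a tensordict

# functional-style call: the tensors in `params` are swapped into the module
# for the duration of the context, then the originals are restored
with params.to_module(module):
    module(torch.zeros(1))

# with the in-place update above, the change is visible here because the
# mutated tensor is the one held by `params`; if forward re-bound the buffer
# instead (self.buffer_a = self.buffer_a + 1), the new tensor would never
# make it into `params`
print(params["buffer_a"])
```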
I'm doing the refactoring of the losses. One thing that makes this impossible currently is that we're using the same module for target and regular calls, just swapping the parameters and buffers using a functional call. Hence, we'll have multiple calls to the module during a single forward and it's unclear which one should affect the original module (one could expect that the module is only changed by the calls made with regular params, not target, but it doesn't seem as clear as that in general).
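To make that ambiguity concrete, a rough sketch continuing the example above (`target_params` here is just a clone, standing in for the loss's target parameters):

```python
module = BufferModule()
params = TensorDict.from_module(module)   # "regular" params/buffers
target_params = params.clone()            # stand-in for the target copy

x = torch.zeros(1)

# the same module object is called twice during one loss forward,
# each time with a different set of tensors swapped in
with params.to_module(module):
    out = module(x)            # call with regular params
with target_params.to_module(module):
    target_out = module(x)     # call with target params

# if forward re-binds a buffer, it is not obvious which of these two calls
# should be the one whose new tensor gets written back to the user's module
```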
I think the answer will be that users should either make the changes in-place or re-sync the module params with some custom code like

```python
loss_module(data)
loss_module.actor_params.to_module(policy_module)
```

which will update the `policy_module` attributes with the updated params and buffers. This solution will support non in-place modifications of buffers.
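A rough sketch of what that re-sync could look like in a training loop (the exact attribute name of the parameter tensordict depends on the loss class; `data`, `policy_module` and `optimizer` are assumed to exist here):

```python
# run the loss as usual
loss_td = loss_module(data)
loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
loss.backward()
optimizer.step()
optimizer.zero_grad()

# re-sync: write the params/buffers held by the loss back into the original
# policy module, so that non in-place buffer changes (new tensor identities)
# also show up on policy_module
loss_module.actor_params.to_module(policy_module)
```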
> One thing that makes this impossible currently is that we're using the same module for target and regular calls, just swapping the parameters and buffers using a functional call. Hence, we'll have multiple calls to the module during a single forward and it's unclear which one should affect the original module (one could expect that the module is only changed by the calls made with regular params, not target, but it doesn't seem as clear as that in general).
Why do we care which one affects the buffers?
Wouldn't we just need to reflect any change in the buffers as soon as it happens (independently of the context of the call)? For example, in my script the user changes the buffer only `if torch.is_grad_enabled():`, thus making the change only in the non-target case.
> `loss_module.actor_params.to_module(policy_module)`
This confuses me a bit: are the buffers considered params? Are we creating target buffers? In my opinion buffers should not be included with the targets or params and should be shared at all times among all the instances of the model (target, non-target, collection).
Also, would in-place modification work when collection and loss models are sitting on different devices?
> Why do we care which one affects the buffers?
I don't get that comment, can you elaborate?
> Wouldn't we just need to reflect any change in the buffers as soon as it happens (independently of the context of the call)? For example, in my script the user changes the buffer only `if torch.is_grad_enabled():`, thus making the change only in the non-target case.
This is custom logic in your module that we can't capture in the loss module (we can't read through the module and detect the calls to `is_grad_enabled()`). Also, it isn't a given that computations with target params have grad disabled (e.g. meta-RL and many other cases where you still build a graph with target params).
I'm dubious about using whether or not the current op is part of a graph as a way of detecting inference vs training: there are several cases where gradient computation is required during inference, and similarly many cases where gradients are disabled during training. I think this shouldn't be used in RL (maybe I misunderstood the example though).
Per se, your params (and potentially the target params) will be affected, but only within the loss module.
For info, we need to do a deep copy of the stateless module before putting it in the loss module, because we store the parameters independently of the module itself. If we don't do the deep copy, we will have duplicates of the params when calling `loss_module.parameters()`. Given this, any modification of the module itself will be lost, and I cannot think of a way of avoiding that except not using functional parameters, which is a no-go atm.
> This confuses me a bit: are the buffers considered params? Are we creating target buffers? In my opinion buffers should not be included with the targets or params and should be shared at all times among all the instances of the model (target, non-target, collection).
That's controversial :) Take batch-norm: if you're using something like that, your buffers are synced with the params and should not be shared. I think this is the general case, i.e. buffers and params are synced.
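For concreteness, batch-norm's running statistics are plain PyTorch buffers that are updated during training alongside the weights they were computed with:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']

bn.train()
bn(torch.randn(8, 4))      # a forward pass in train mode updates the running stats
print(bn.running_mean)     # no longer all zeros

# these buffers describe the data seen by this particular set of weights, so a
# target network generally wants its own (target) copy rather than sharing
# them with the live network
```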
Ok, so what I understand from this is that you are saying the only possibility for buffers is to be part of the parameters when the loss makes the module functional.
And given this, the only thing we can do is update them in-place or have some code to copy them over.
But if we update them in-place, we would still need the code to copy them over (something like `collector._update_policy_weights()`), because the two models could be on different devices.
> That's controversial :) Take batch-norm: if you're using something like that, your buffers are synced with the params and should not be shared. I think this is the general case, i.e. buffers and params are synced.
So the batch-norm buffers are part of the target params and are updated with the same strategy (hard/soft)? Is this how it is normally done?
Not sure what "normally" refers to, but yes, this is how it is done atm.
This is now handled as well as we can with functional calls; closing as "not planned" because I'm not sure how that would unfold when using target networks, meta-learning, etc.
Describe the bug
When a model is passed to a loss and a buffer is updated on that model, the update is not visible on the passed model. It works if the buffer is updated in-place.
To Reproduce