Open smorad opened 1 year ago
Thanks for raising this. Seems like an interesting problem to tackle.
Regarding basic implementation, I think we should keep a contiguous stack of params and use vmap over the Q-Value network, that should be faster than looping over the networks / params.
Next, you're raising the question of whether and how to select the min value. We do that in a very custom fashion in the losses, but I agree that we should find a more generic way of doing so. If we get a table of state-action values like
[[v_00, v_01],
 [v_10, v_11]]
where rows index actions 0, 1 and columns index networks 0, 1, how do you select an action using the min operator (knowing that you want the action with the maximum value)?
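One common answer, sketched here as an assumption rather than as what TorchRL does internally, is to take the minimum over the network dimension first (a pessimistic value per action) and then the argmax over actions. The numbers in `q_table` are made up for illustration.

```python
import torch

# Hypothetical Q-table: rows are actions, columns are ensemble members (networks).
q_table = torch.tensor([[1.0, 3.0],
                        [2.5, 2.0]])

# Pessimistic value of each action: minimum over the network dimension.
q_min = q_table.min(dim=1).values

# Greedy action with respect to the pessimistic values.
action = q_min.argmax().item()
```

Note that a plain max over the whole table would pick action 0 here (because of the 3.0 entry), whereas min-then-argmax picks action 1.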
IMO we could build the backbone that produces the table and append an action selection strategy like we do for QValue networks (I'm trying to move away from wrappers, as wrapping quickly puts things in super nested structures where you don't really know where your original module lives). So we'd have something like
policy = TensorDictSequential(
    EnsembleStateActionValue(...),  # uses vmap, writes a table of values in the output tensordict
    EnsembleQValueActor(...),  # selects the action given some heuristic
)
In the loss function, we can either pass the entire policy, which will work since TensorDictSequential keeps track of intermediate values in the output tensordict (but risky: one could overwrite the actions in the tensordict), or just pass the EnsembleStateActionValue and let the loss deal with it.
Happy to sketch a solution in a notebook if that helps!
Do you think it makes sense to break this into three modules, so we can utilize the default QValueActor?
policy = TensorDictSequential(
    Ensemble(in_keys=['observation'], out_keys=['ensemble_state_action_value']),
    Reduce(in_keys=['ensemble_state_action_value'], out_keys=['state_value_action'], reduce_fn=lambda x, dim: x.min(dim=dim).values),
    QValueActor(env.action_spec),
)
This keeps ensemble/reduce more general as they could be useful outside of Q functions.
I think that could work
Here's how I would go about the vmap module:
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase, TensorDictModule
from torch import nn
import torch

net = nn.Sequential(nn.Linear(10, 128), nn.Tanh(), nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128))
module = TensorDictModule(net, in_keys=["in"], out_keys=["out"])

class VmapParamModule(TensorDictModuleBase):
    def __init__(self, module, num_copies):
        super().__init__()
        params = TensorDict.from_module(module)
        params = params.expand(num_copies).to_tensordict()
        self.in_keys = module.in_keys
        self.out_keys = module.out_keys
        self.params_td = params
        self.params = nn.ParameterList(list(params.values(True, True)))
        self.module = module

    def forward(self, td):
        return torch.vmap(self.module, (None, 0))(td, self.params_td)

vmap_module = VmapParamModule(module, 2)
td = TensorDict({"in": torch.randn(10)}, [])
vmap_module(td)
This gives you a td:
TensorDict(
    fields={
        in: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
Great, thanks. One last thing: I'm struggling with a nice way to reinitialize the copied parameters. Calling reset_parameters on each module individually would be ideal, but not all torch modules have this. For example, nn.Sequential does not expose reset_parameters.
Furthermore, by disconnecting the modules from the parameters via TensorDict.from_module, we can no longer call reset_parameters on each copy of the parameters. We could pass a parameter_init_function to VmapParamModule, but this will become ugly fast. For example, consider an MLP with a special final layer init. The linear bias init requires fan_in, which depends on the associated weight. The final layer of the MLP should do something like nn.init.normal_(weight, 0, 1e-4), bias.zero_(). So now we need to work out how to associate a weight with a bias, and we also need to figure out which weight and bias belong to the final layer. I don't even wanna think about adding a CNN to the mix.
It's ugly, but perhaps something like the following is the best we can do?
from copy import deepcopy

modules = [deepcopy(module) for _ in range(num_copies)]
for m in modules:
    user_defined_reset_weights_fn_(m)  # re-initializes each copy in place
params_td = TensorDict({f"copy_{k}": TensorDict.from_module(modules[k]) for k in range(num_copies)}, batch_size=[])
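For comparison, the same copy-then-reset idea can be sketched with plain PyTorch (2.0 or later) using `torch.func.stack_module_state` and `torch.func.functional_call`, which stack the parameters along a leading dimension so that `vmap` can run all copies in one call. This is only an illustration under assumptions: `reset_` is a hypothetical stand-in for the user-defined reset function, and the layer sizes are arbitrary.

```python
from copy import deepcopy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state


def reset_(module: nn.Module) -> None:
    # Hypothetical user-defined reset: re-run the default init of every Linear layer.
    for m in module.modules():
        if isinstance(m, nn.Linear):
            m.reset_parameters()


base = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 4))
num_copies = 3

# Independent copies, each re-initialized while it is still a regular module.
copies = [deepcopy(base) for _ in range(num_copies)]
for m in copies:
    reset_(m)

# Stack parameters/buffers along a new leading "ensemble" dimension.
params, buffers = stack_module_state(copies)


def call_one(p, b, x):
    # Run `base` with one slice of the stacked parameters.
    return functional_call(base, (p, b), (x,))


x = torch.randn(5, 10)
# Map over the ensemble dimension of params/buffers, share the input.
out = torch.vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
```

Here `out` has one leading entry per ensemble member, and because the reset happens before the parameters are detached from their modules, no weight/bias bookkeeping is needed.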
Another solution would be to add a recursive reset_parameters to TensorDictModuleBase; perhaps this would be useful elsewhere in torchrl? Then, we could reinitialize CNN/LSTM/MLP TensorDictModule parameters without having to write different code for each TensorDictModule.
class TensorDictModuleBase:
    ...
    def reset_parameters(self):
        # Start from the children: self now defines reset_parameters itself,
        # so checking hasattr on self would recurse forever.
        for m in self.children():
            self._reset_parameters(m)

    def _reset_parameters(self, module: nn.Module):
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
        else:
            for m in module.children():
                self._reset_parameters(m)
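The recursion can be tried out on plain `nn.Module`s without touching TensorDictModuleBase. The sketch below is a standalone helper (the name `reset_parameters_recursive` is made up for illustration), not an existing TorchRL or tensordict API.

```python
import torch
from torch import nn


def reset_parameters_recursive(module: nn.Module) -> None:
    # Leaf layers such as nn.Linear define reset_parameters; containers such as
    # nn.Sequential do not, so descend into their children instead.
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
    else:
        for child in module.children():
            reset_parameters_recursive(child)


net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Sequential(nn.Linear(8, 2)))
before = net[0].weight.detach().clone()
reset_parameters_recursive(net)
after = net[0].weight.detach().clone()
```

One caveat of this approach: a custom module that owns parameters directly but does not define reset_parameters is silently skipped.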
Motivation
Twin Q/ensemble Q functions are used in many RL algorithms and mitigate Q overestimation. My understanding is that TorchRL only deals with ensembles in the loss functions. This is fine for actor/critic methods since we only use the critics to compute actor loss. But for critic-only methods (e.g. DQN), we need the Q ensemble at sample collection time. Doing so would also simplify the loss functions for DDPG/SAC/REDQ/etc.
Solution
I would like to add ensemble Q function support to TorchRL, but I'm not sure of the best way to do this. I was thinking of creating a TensorDictModuleEnsemble in tensordict_module.py that could be used for more than just Q functions. The issue is that we essentially need two forward functions: one at sample time to compute some reduce operation like min over the ensemble outputs, and one at training time that does not reduce, so we can compute the loss for all Q functions. I'm not sure if there is a good way to tell a TensorDictModule whether it is in "sampling" or "training" mode.
I think it also makes sense to provide an option to keep separate datasets for each model/Q function. I'd like to avoid a for-loop over the models for this if possible, but I'm not sure how.
Additional context
Related to https://github.com/pytorch/rl/issues/876