XuehaiPan opened this issue 2 years ago.
Thanks for this detailed description, @XuehaiPan.
Can you elaborate on this?
> Case.2 LPG: Register non-leaf meta-parameters as buffers in the loss module on every outer update
>
> For Learning Policy Gradient (LPG) https://arxiv.org/abs/1805.09801, it takes the LSTM network as the meta-parameter.
>
> Different from MGRL, on each update, the meta-network output is not a leaf tensor anymore. Then we need to register these output again and again before each call of `loss_module.forward`. This makes `fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)` not working.
As far as I understand it, the parameters of the LSTM (if they're part of the model) will end up in `params` when calling `functorch.make_functional_with_buffers`, and they won't need to be leaf tensors anymore.
But I guess I misunderstood your point.
@vmoens Sorry for the late reply and for the wrong link to the LPG algorithm.
Background: Policy Gradient Theorem (Ref: High-Dimensional Continuous Control Using Generalized Advantage Estimation (GAE))
$$ g = \mathbb{E}\left[ \sum_{t=0}^{\infty} \Psi_t \nabla_\theta \log \pi_\theta(a_t \mid s_t) \right] \qquad (1) $$
The main contribution of the GAE paper is a new rule for computing $\Psi_t$:
$$ \Psi_t = \operatorname{GAE} (r_{0:t}, \gamma, \lambda) $$
GAE is computed by a carefully designed but fixed rule.
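For reference, the fixed GAE rule can be written as a backward recursion over the TD residuals $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$, with $\hat{A}_t = \delta_t + \gamma\lambda \hat{A}_{t+1}$ (note that GAE also uses a value function $V$). A minimal pure-Python sketch, assuming a single trajectory where `values` carries one extra bootstrap entry $V(s_T)$; the function name `gae` is illustrative:

```python
from typing import List


def gae(rewards: List[float], values: List[float], gamma: float, lam: float) -> List[float]:
    """Generalized Advantage Estimation for one trajectory.

    `values` has len(rewards) + 1 entries; the last one is the bootstrap V(s_T).
    """
    assert len(values) == len(rewards) + 1
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Iterate backwards: A_t = delta_t + gamma * lam * A_{t+1}
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD residual
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


print(gae([1.0, 1.0], [0.0, 0.0, 0.0], gamma=0.5, lam=1.0))  # [1.5, 1.0]
```

With `lam=0` the estimate reduces to the one-step TD residuals; with `lam=1` it becomes the discounted return minus $V(s_t)$. Either way, the rule itself has no learnable parameters, which is what LPG changes.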
For Learned Policy Gradient (LPG) https://arxiv.org/abs/2007.08794, a separate LSTM network (parameterized by $\phi$) is learned in place of this fixed rule to produce $\Psi_t$ in equation (1) above, i.e.:
$$ \Psi_t = \operatorname{LSTM} (r_{0:t}; \phi) $$
Then the policy-gradient-based RL policy (parameterized by $\theta$) uses the output $\Psi_t$ to solve the underlying task.
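A minimal PyTorch sketch of the simplified formula above, assuming a toy discrete policy and a one-layer `nn.LSTM` that only reads rewards; the class name `MetaNet`, the shapes, and the sizes are illustrative and not the full LPG architecture. It only shows that $\Psi_t$ is a non-leaf tensor computed from $\phi$, so the policy loss is differentiable with respect to both $\theta$ and $\phi$. A full meta-gradient for $\phi$ would additionally differentiate through the policy update (e.g. with a differentiable optimizer such as TorchOpt).

```python
import torch
from torch import nn

torch.manual_seed(0)
T, obs_dim, n_actions, hidden = 5, 4, 3, 8


class MetaNet(nn.Module):
    """Hypothetical meta-network (phi): maps a reward sequence r_{0:t} to Psi_t."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden, batch_first=True)
        self.head = nn.Linear(hidden, 1)

    def forward(self, rewards: torch.Tensor) -> torch.Tensor:
        # rewards: [B, T] -> LSTM over time -> Psi: [B, T]
        out, _ = self.lstm(rewards.unsqueeze(-1))
        return self.head(out).squeeze(-1)


meta_net = MetaNet(hidden)               # parameters phi
policy = nn.Linear(obs_dim, n_actions)   # parameters theta (action logits)

obs = torch.randn(1, T, obs_dim)
actions = torch.randint(n_actions, (1, T))
rewards = torch.randn(1, T)

psi = meta_net(rewards)  # Psi_t: a non-leaf tensor that depends on phi
log_probs = torch.distributions.Categorical(logits=policy(obs)).log_prob(actions)
loss = -(psi * log_probs).mean()  # surrogate whose theta-gradient is -mean(Psi_t * grad log pi)

loss.backward()
print(psi.is_leaf)                                  # False
print(policy.weight.grad is not None)               # True: gradient for theta
print(meta_net.lstm.weight_ih_l0.grad is not None)  # True: gradient also reaches phi
```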
The use case is:
Note that the parameters $\phi$ are shared across multiple `LossModule`s, so we need to register the LSTM network as a sub-module or buffer of each `LossModule`. This can be a large number of tensors, rather than the single tensor (`gamma`) in the MGRL example. When we then make the `LossModule` functional using `functorch`, it is hard to distinguish which params are the meta-parameters from the LSTM and which are the parameters of the RL policy:
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
We need to maintain indices for this and pass the parameters carefully:
params = ... # manipulate the parameters carefully
loss = fmodel(params, buffers, batch)
In the feature request, we propose computing all necessary data before calling the loss function rather than inside it. Then we can do:
meta_fmodel, meta_params, meta_buffers = functorch.make_functional_with_buffers(meta_module)
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
batch = meta_fmodel(meta_params, meta_buffers, batch)
batch = fmodel(params, buffers, batch)
loss = loss_fn(batch)
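A runnable sketch of this split, assuming `torch.func.functional_call` (which PyTorch 2.x recommends in place of `functorch.make_functional*`) and a plain dict standing in for `batch`. The modules, the `"psi"`/`"log_prob"` keys, and the helpers `meta_step`, `policy_step`, and `loss_fn` are hypothetical, and a small MLP stands in for the LSTM for brevity. Because the meta-network and the policy are separate modules, each gets its own name-keyed parameter dict and no index bookkeeping is needed:

```python
from typing import Dict

import torch
from torch import nn
from torch.func import functional_call

Batch = Dict[str, torch.Tensor]

torch.manual_seed(0)
T, obs_dim, n_actions = 5, 4, 3

# Hypothetical modules: a small MLP stands in for the LSTM meta-network (phi).
meta_module = nn.Sequential(nn.Linear(1, 8), nn.Tanh(), nn.Linear(8, 1))
policy = nn.Linear(obs_dim, n_actions)  # RL policy (theta)

# Two independent, name-keyed parameter sets: no indices to maintain.
meta_params = {k: v.detach().requires_grad_() for k, v in meta_module.named_parameters()}
params = {k: v.detach().requires_grad_() for k, v in policy.named_parameters()}


def meta_step(p: Batch, batch: Batch) -> Batch:
    # The meta-network writes Psi_t into the batch.
    psi = functional_call(meta_module, p, (batch["reward"].unsqueeze(-1),)).squeeze(-1)
    return {**batch, "psi": psi}


def policy_step(p: Batch, batch: Batch) -> Batch:
    logits = functional_call(policy, p, (batch["obs"],))
    log_prob = torch.distributions.Categorical(logits=logits).log_prob(batch["action"])
    return {**batch, "log_prob": log_prob}


def loss_fn(batch: Batch) -> torch.Tensor:
    # Pure function: holds no parameters, only reads tensors from the batch.
    return -(batch["psi"] * batch["log_prob"]).mean()


batch = {"obs": torch.randn(T, obs_dim), "action": torch.randint(n_actions, (T,)), "reward": torch.randn(T)}
batch = meta_step(meta_params, batch)
batch = policy_step(params, batch)
loss = loss_fn(batch)

meta_grads = torch.autograd.grad(loss, list(meta_params.values()), retain_graph=True)
policy_grads = torch.autograd.grad(loss, list(params.values()))
print(len(meta_grads), len(policy_grads))  # 4 2
```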
Thanks for the answer
> We need to maintain indices for this and pass the parameters carefully.
If I understand correctly, what bothers you with the current API is that, with functorch as it is now, you are working with a list of parameters/buffers, which makes it hard to assign them to a particular module. Is that right?
What if we were storing the params in a dictionary instead (or better: a TensorDict :D )? @zou3519 mentioned there that some users were asking for this. I'm happy to sketch a solution like this if that would unblock you (and/or implement the solution you were suggesting, even temporarily while waiting for a functorch fix if we opt for it).
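A sketch of the dictionary idea using `torch.func.functional_call`, which accepts parameters and buffers keyed by their fully qualified names (and also a tuple of dicts with distinct keys). It assumes a hypothetical `LossModule` that holds the shared meta-network under the attribute name `meta_net`; selecting the meta-parameters then becomes a prefix match on the keys rather than tracking positions in a flat tuple:

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)


class LossModule(nn.Module):
    """Hypothetical loss module holding both the policy (theta) and a shared meta-net (phi)."""

    def __init__(self, meta_net: nn.Module) -> None:
        super().__init__()
        self.actor = nn.Linear(4, 3)
        self.meta_net = meta_net  # shared across several LossModule instances

    def forward(self, obs: torch.Tensor, action: torch.Tensor, reward: torch.Tensor) -> torch.Tensor:
        psi = self.meta_net(reward.unsqueeze(-1)).squeeze(-1)
        log_prob = torch.distributions.Categorical(logits=self.actor(obs)).log_prob(action)
        return -(psi * log_prob).mean()


meta_net = nn.Linear(1, 1)
loss_module = LossModule(meta_net)

# Names instead of positions: split the parameter dict by prefix.
named = {k: v.detach().requires_grad_() for k, v in loss_module.named_parameters()}
meta_params = {k: v for k, v in named.items() if k.startswith("meta_net.")}
policy_params = {k: v for k, v in named.items() if not k.startswith("meta_net.")}
print(sorted(meta_params))  # ['meta_net.bias', 'meta_net.weight']

obs, action, reward = torch.randn(5, 4), torch.randint(3, (5,)), torch.randn(5)
# A tuple of dicts with disjoint keys is merged for the call.
loss = functional_call(loss_module, (policy_params, meta_params), (obs, action, reward))
meta_grads = torch.autograd.grad(loss, list(meta_params.values()))
```

A `TensorDict` would allow the same split on nested keys; that variant is not shown here.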
> If I understand correctly, what bothers you with the current API is that, with functorch as it is now, you are working with a list of parameters/buffers, which makes it hard to assign them to a particular module. Is that right?
Yes, that's right.
> and/or implement the solution you were suggesting
Thanks for this. The new implementation would resolve both points 1 and 2 in https://github.com/facebookresearch/rl/issues/338#issue-1328464480.
Motivation
1. Consistent style for `torch.nn.modules.loss.*Loss`
In `torch.nn.modules.loss`, there are many `*Loss` classes subclassing `nn.Module`. Their `__init__()` does not take other `nn.Module`s as arguments, and their `forward()` method is purely functional and directly calls `nn.functional.*_loss`. I think the motivation for providing `torch.nn.modules.loss.*Loss` is composing networks with `nn.Sequential(...)`.

2. More straightforward implementation for functional-style algorithms, such as meta-RL algorithms
In many meta-RL algorithms, the policy is trained with meta-parameters that may not be registered to the `LossModule`.

Case.1 MGRL: Register leaf meta-parameters as buffers in the loss module
For Meta-Gradient Reinforcement Learning (MGRL) https://arxiv.org/abs/1805.09801, the discount factor `gamma` is treated as the meta-parameter across RL updates. Take PPO as an example:
See https://github.com/metaopt/TorchOpt#torchopt-as-differentiable-optimizer-for-meta-learning for figures.
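A minimal sketch of Case.1, assuming a hypothetical `DiscountedLoss` module that stores `gamma` as a buffer and uses a toy discounted-return regression loss instead of the full PPO objective. Because `gamma` is a leaf tensor, `torch.func.functional_call` can swap in a differentiable value for one call, and the loss is then differentiable with respect to it. A full MGRL update would also differentiate through the policy update; that part is omitted here.

```python
import torch
from torch import nn
from torch.func import functional_call


class DiscountedLoss(nn.Module):
    """Hypothetical loss module: regresses a value head onto discounted returns."""

    def __init__(self, gamma: float) -> None:
        super().__init__()
        self.value = nn.Linear(4, 1)
        # The meta-parameter lives in the buffers, not in the parameters.
        self.register_buffer("gamma", torch.tensor(gamma))

    def forward(self, obs: torch.Tensor, rewards: torch.Tensor) -> torch.Tensor:
        # Discounted returns G_t = r_t + gamma * G_{t+1}, computed backwards.
        returns = []
        g = torch.zeros(())
        for r in reversed(rewards.unbind(0)):
            g = r + self.gamma * g
            returns.append(g)
        target = torch.stack(returns[::-1])
        return ((self.value(obs).squeeze(-1) - target) ** 2).mean()


torch.manual_seed(0)
loss_module = DiscountedLoss(gamma=0.99)
obs, rewards = torch.randn(6, 4), torch.randn(6)

# Outer loop: a differentiable gamma replaces the registered buffer for this call only.
meta_gamma = torch.tensor(0.99, requires_grad=True)
loss = functional_call(loss_module, {"gamma": meta_gamma}, (obs, rewards))
(grad_gamma,) = torch.autograd.grad(loss, meta_gamma)
print(grad_gamma.shape)  # torch.Size([])
```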
We need to register our meta-parameter `gamma` as a buffer of the loss module, rather than leaving the parameters fully under the user's control. For integration with `functorch`, registering the meta-parameter as a module buffer works without problems.

Case.2 LPG: Register non-leaf meta-parameters as buffers in the loss module on every outer update
For Learned Policy Gradient (LPG) https://arxiv.org/abs/2007.08794, the LSTM network is taken as the meta-parameter.
Unlike MGRL, the meta-network output is no longer a leaf tensor on each update, so we need to register these outputs again before every call of `loss_module.forward`. This prevents `fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)` from working.
cc @Benjamin-eecs @waterhorse1
Solution
Split the `forward` method in the loss module into a separate pure function, i.e., a stateless function that does not hold any parameters. The model parameters should be organized by other modules. The loss function only takes a `tensordict` as input and adds a new key `"loss_objective"` to the `tensordict`. All tensor inputs (e.g., `value = self.critic(...)`) should be computed before calling the loss function, because the loss function is purely functional, i.e., it does not host parameters (e.g., `actor.parameters()`, `critic.parameters()`). Here is a prototype example:
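A minimal sketch of such a stateless loss, assuming a plain `dict` of tensors stands in for the `tensordict` and using the standard clipped PPO surrogate; the function name `ppo_loss` and the input keys (`"log_prob"`, `"old_log_prob"`, `"advantage"`) are hypothetical. Everything the loss needs is computed beforehand and read from the dict, and the result is written back under `"loss_objective"`:

```python
from typing import Dict

import torch


def ppo_loss(td: Dict[str, torch.Tensor], clip_epsilon: float = 0.2) -> Dict[str, torch.Tensor]:
    """Pure clipped-surrogate PPO loss: no parameters, only tensors from `td`."""
    ratio = torch.exp(td["log_prob"] - td["old_log_prob"])
    advantage = td["advantage"]
    unclipped = ratio * advantage
    clipped = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage
    # PPO maximizes the minimum of the two surrogates, so the loss is its negation.
    td["loss_objective"] = -torch.min(unclipped, clipped).mean()
    return td


td = {
    "log_prob": torch.tensor([0.0, 0.0]),
    "old_log_prob": torch.tensor([0.0, 0.0]),
    "advantage": torch.tensor([1.0, -3.0]),
}
print(ppo_loss(td)["loss_objective"])  # tensor(1.)
```

Since `ppo_loss` holds no state, the actor, critic, and any meta-network can be made functional independently and simply fill in the input keys.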
For backward compatibility, refactor the `PPOLoss` module as:

Alternatives
Copy and paste the loss module source code, then make algorithm-specific customizations.