pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[Feature Request] Purely functional loss objectives #338

Open XuehaiPan opened 2 years ago

XuehaiPan commented 2 years ago

Motivation

1. Consistent style for torch.nn.modules.loss.*Loss

In torch.nn.modules.loss, there are many *Loss classes that subclass nn.Module. Loss.__init__() does not take other nn.Module instances as arguments, and Loss.forward() is purely functional: it directly calls nn.functional.*_loss.

I think the motivation for using torch.nn.modules.loss.*Loss is composing networks with nn.Sequential(...).
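To make point 1 concrete, here is a minimal sketch comparing a module-style loss with its functional counterpart. `nn.MSELoss` and `F.mse_loss` are standard PyTorch; the tensors are made-up data for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

prediction = torch.tensor([0.5, 1.5, 2.0])
target = torch.tensor([1.0, 1.0, 2.0])

# The module only stores configuration (here, the reduction mode) ...
criterion = nn.MSELoss(reduction="mean")
module_loss = criterion(prediction, target)

# ... and the stateless functional form computes the same value.
functional_loss = F.mse_loss(prediction, target, reduction="mean")

# The loss module holds no learnable parameters of its own.
num_params = len(list(criterion.parameters()))
```

The request in this issue is essentially to have the same split for the RL objectives: a stateless function plus a thin module wrapper.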

2. More straightforward implementation for functional style algorithms, such as meta-RL algorithms

In many meta-RL algorithms, the policy is trained with meta-parameters that may not be registered in the LossModule.

Case.1 MGRL: Register leaf meta-parameters as buffers in the loss module

Meta-Gradient Reinforcement Learning (MGRL) https://arxiv.org/abs/1805.09801 treats the discount factor gamma as a meta-parameter that is learned across RL updates.

Using PPO as an example:

import torch
import torch.nn as nn

from torchrl.objectives import PPOLoss

import torchopt

### Setup ###
meta_param = nn.Parameter(torch.tensor(0.95))
loss_module = PPOLoss(
    actor, critic, ...,
    gamma=None,  # whatever value
)
loss_module.register_buffer('gamma', meta_param)  # register gamma as buffer

### Optimizers ###
inner_optim = torchopt.MetaAdam(loss_module)
outer_optim = torchopt.Adam([meta_param])

### Inner update (update network parameters) ####
inner_loss1 = loss_module(tensordict)
inner_optim.step(inner_loss1['loss_objective'])  # inner update 1: param(0) -> param(1)

...

inner_lossN = loss_module(tensordict)
inner_optim.step(inner_lossN['loss_objective'])  # inner update N: param(N - 1) -> param(N)

### Outer update (update meta-parameters (gamma)) ###
outer_loss = loss_module(tensordict)  # sampled by param(N)
outer_optim.zero_grad()
outer_loss['loss_objective'].backward()  # the loss module returns a tensordict, so select the loss entry
outer_optim.step()

See https://github.com/metaopt/TorchOpt#torchopt-as-differentiable-optimizer-for-meta-learning for figures.

With this approach, we need to register the meta-parameter gamma as a buffer of the loss module, rather than leaving the user in full control of the parameters.
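For readers unfamiliar with meta-gradients, the following is a minimal sketch of how a gradient can reach gamma through a differentiable inner update. It uses only plain PyTorch with a hand-written SGD step instead of torchopt.MetaAdam, and `discounted_return`, `theta`, the step size 0.1 and the target 3.0 are all hypothetical stand-ins, not part of torchrl or TorchOpt.

```python
import torch

def discounted_return(rewards: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    """Sum over t of gamma^t * r_t, differentiable with respect to gamma."""
    steps = torch.arange(rewards.shape[0], dtype=rewards.dtype)
    return (gamma ** steps * rewards).sum()

gamma = torch.tensor(0.95, requires_grad=True)  # meta-parameter (leaf tensor)
theta = torch.tensor(0.0, requires_grad=True)   # stand-in for the network parameters
rewards = torch.tensor([1.0, 1.0, 1.0])

# Inner update: regress theta towards the gamma-discounted return.
inner_loss = (theta - discounted_return(rewards, gamma)) ** 2
(grad_theta,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
theta_new = theta - 0.1 * grad_theta            # differentiable step, depends on gamma

# Outer update: evaluate the updated parameter against a fixed target.
outer_loss = (theta_new - 3.0) ** 2
outer_loss.backward()                           # populates gamma.grad
```

Here gamma stays a user-owned leaf tensor the whole time, which is the kind of control the buffer-based workaround takes away.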

For integration with functorch, registering the meta-parameter as a module buffer works out of the box.

meta_param = nn.Parameter(torch.tensor(0.95))
loss_module = PPOLoss(
    actor, critic, ...,
    gamma=None,  # whatever value
)
loss_module.register_buffer('gamma', meta_param)  # register gamma as buffer

# Make functional
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

Case.2 LPG: Register non-leaf meta-parameters as buffers in the loss module on every outer update

Learned Policy Gradient (LPG) https://arxiv.org/abs/2007.08794 treats an LSTM network as the meta-parameter.

Unlike in MGRL, the meta-network output is recomputed on each update and is no longer a leaf tensor. We therefore need to register this output again before every call to loss_module.forward. As a result, the following no longer works:

fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
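The leaf versus non-leaf distinction can be checked directly. This is a small sketch in plain PyTorch; the LSTM sizes and the random rewards are arbitrary choices for illustration and are not taken from the LPG paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

gamma = nn.Parameter(torch.tensor(0.95))         # MGRL-style meta-parameter
meta_net = nn.LSTM(input_size=1, hidden_size=4)  # LPG-style meta-network (arbitrary sizes)

rewards = torch.randn(5, 1, 1)                   # (time, batch, feature)
meta_output, _ = meta_net(rewards)

gamma_is_leaf = gamma.is_leaf                    # True: can be registered once and reused
output_is_leaf = meta_output.is_leaf             # False: produced by a forward pass
output_has_grad_fn = meta_output.grad_fn is not None
```

Because `meta_output` is rebuilt on every forward pass, a buffer registered once would go stale, hence the repeated registration described above.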

cc @Benjamin-eecs @waterhorse1

Solution

Split the forward method of the loss module out into a separate pure function, i.e., a stateless function that holds no parameters. The model parameters should be organized by other modules. The loss function only takes a tensordict as input and adds a new key "loss_objective" to it. All tensor inputs (e.g. value = self.critic(...)) should be computed before calling the loss function, because the loss function is purely functional and does not host parameters (e.g., actor.parameters(), critic.parameters()).

Here is a prototype example:

def ppo_loss(tensordict: TensorDictBase, dist_class: Distribution, **kwargs) -> TensorDictBase:
    tensordict = tensordict.clone()
    gamma = tensordict.get('gamma', kwargs.get('gamma'))  # In MGRL, the `gamma` parameter can be a tensor rather than a Python scalar
    critic_coef = tensordict.get('critic_coef', kwargs.get('critic_coef'))  # same spelling as the PPOLoss attribute
    dist = dist_class(tensordict.get('actor_output'))
    ...
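To show the shape of such a function end to end, here is a runnable sketch of a stateless clipped-PPO objective. It is only an illustration under assumptions: a plain dict stands in for TensorDictBase, the keys "action", "old_log_prob" and "advantage" and the name `ppo_clip_loss` are hypothetical, and only "actor_output" and "loss_objective" come from the prototype above.

```python
from typing import Dict

import torch
from torch.distributions import Categorical

def ppo_clip_loss(batch: Dict[str, torch.Tensor], clip_epsilon: float = 0.2) -> Dict[str, torch.Tensor]:
    """Stateless clipped PPO objective: reads precomputed tensors, returns a new dict."""
    dist = Categorical(logits=batch["actor_output"])
    ratio = (dist.log_prob(batch["action"]) - batch["old_log_prob"]).exp()
    advantage = batch["advantage"]
    unclipped = ratio * advantage
    clipped = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage
    out = dict(batch)  # shallow copy, so the input dict is not mutated
    out["loss_objective"] = -torch.min(unclipped, clipped).mean()
    return out

# All network outputs are computed outside the loss function.
logits = torch.zeros(4, 3, requires_grad=True)  # stand-in for the actor output
action = torch.tensor([0, 1, 2, 0])
batch = {
    "actor_output": logits,
    "action": action,
    "old_log_prob": Categorical(logits=logits.detach()).log_prob(action),
    "advantage": torch.tensor([1.0, -1.0, 0.5, 2.0]),
}
result = ppo_clip_loss(batch)
result["loss_objective"].backward()  # gradients flow back to whatever produced the logits
```

Since the function owns no parameters, the caller decides which tensors are leaves, meta-parameters, or outputs of other networks.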

For backward compatibility, refactor the PPOLoss module as:

class PPOLoss(LossModule):
    actor: nn.Module
    critic: nn.Module
    gamma: float
    entropy_coef: float
    critic_coef: float
    advantage_module: nn.Module

    def __init__(self, ...):
        ...

        self.floss_kwargs = dict(
            gamma=self.gamma,
            entropy_coef=self.entropy_coef,
            critic_coef=self.critic_coef,
        )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Prepare all inputs for loss function
        if self.advantage_module is not None:
            tensordict = self.advantage_module(
                tensordict,
            )
        tensordict = tensordict.clone()
        ...

        # Call purely functional version of loss function
        return ppo_loss(
            tensordict,
            dist_class,
            **self.floss_kwargs
        )

Alternatives

Copy and paste the loss module source code, then do specific customizations.

vmoens commented 2 years ago

Thanks for this detailed description @XuehaiPan

Can you elaborate on this?

Case.2 LPG: Register non-leaf meta-parameters as buffers in the loss module on every outer update

For Learning Policy Gradient (LPG) https://arxiv.org/abs/1805.09801, it takes the LSTM network as the meta-parameter.

Different from MGRL, on each update, the meta-network output is not a leaf tensor anymore. Then we need to register these output again and again before each call of loss_module.forward. This makes

fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

not working.

As far as I understand it, the parameters of the LSTM (if they're part of the model) will end up in params when calling functorch.make_functional_with_buffers and they won't need to be leaf tensors anymore.

But I guess I misunderstood your point.

XuehaiPan commented 2 years ago

@vmoens Sorry for the late reply and for the wrong link for the LPG algorithm.

Background: Policy Gradient Theorem (Ref: High-Dimensional Continuous Control Using Generalized Advantage Estimation (GAE))

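[Image: screenshot of the policy gradient formula, equation (1), which weights the score function by $\Psi_t$]

(Using a screenshot here due to poor inline math rendering on GitHub)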

The main contribution of the GAE paper is introducing a new rule:

$$ \Psi_t = \operatorname{GAE} (r_{0:t}, \gamma, \lambda) $$

GAE is calculated by a carefully designed rule, which is fixed.
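For reference, the fixed rule can be sketched in a few lines. This is a simplified single-trajectory version without episode-termination handling; the function name `gae` and the argument layout are my own choices, and the value estimates are an extra input that the shorthand above leaves implicit.

```python
import torch

def gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float, lam: float) -> torch.Tensor:
    """Generalized Advantage Estimation for one trajectory.

    `values` has one more entry than `rewards` (bootstrap value of the final state).
    """
    advantages = []
    running = torch.zeros((), dtype=rewards.dtype)
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD residual
        running = delta + gamma * lam * running                 # backward recursion
        advantages.append(running)
    return torch.stack(advantages[::-1])

rewards = torch.tensor([1.0, 1.0, 1.0])
values = torch.zeros(4)
advantages = gae(rewards, values, gamma=0.5, lam=1.0)
```

With zero value estimates and lambda = 1 the result reduces to the discounted return, which is a handy sanity check.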


For Learned Policy Gradient (LPG) https://arxiv.org/abs/2007.08794, a separate LSTM network (parameterized by $\phi$) is learned to fit the rule to get $\Psi_t$ in equation (1) above. I.e.:

$$ \Psi_t = \operatorname{LSTM} (r_{0:t}; \phi) $$

Then the policy-gradient-based RL policy (parameterized by $\theta$) uses the output $\Psi_t$ to solve the underlying task.

The use case is:

  1. We have multiple environments ( $n$ ) to learn, which means we have $n$ separate policy networks (parameterized by $\theta_1, \dots, \theta_n$). The parameters are not shared.
  2. A separate LSTM network (parameterized by $\phi$) is shared for all tasks.

Note that the parameters $\phi$ are shared across multiple LossModule instances, so we need to register the LSTM network as a sub-module or buffer of each LossModule. This can involve a large number of tensors, rather than the single tensor (gamma) in the MGRL example. We then make the LossModule functional using functorch. At that point it is hard to distinguish which entries of params are meta-parameters from the LSTM and which are parameters of the RL policy:

fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

We need to maintain indices for this and pass the parameters carefully:

params = ...  # manipulate the parameters carefully
loss = fmodel(params, buffers, batch)

In the feature request, we ask that all necessary data be computed before calling the loss function rather than inside it. Then we can do:

meta_fmodel, meta_params, meta_buffers = functorch.make_functional_with_buffers(meta_module)
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

batch = meta_fmodel(meta_params, meta_buffers, batch)
batch = fmodel(params, buffers, batch)
loss = loss_fn(batch)
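
A runnable sketch of this pipeline, under assumptions: it uses `torch.func.functional_call` from PyTorch 2.x rather than `functorch.make_functional_with_buffers`, a plain dict instead of a tensordict, an `nn.Linear` as a stand-in for the LSTM meta-network, and hypothetical keys and a hypothetical `loss_fn`.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

meta_module = nn.Linear(1, 1)  # stand-in for the shared meta-network (phi)
policy = nn.Linear(3, 2)       # one task's policy network (theta)

# Each module keeps its own, separately named parameter dict.
meta_params = dict(meta_module.named_parameters())
policy_params = dict(policy.named_parameters())

def loss_fn(batch):
    """Pure function: only reads tensors already stored in the batch."""
    log_prob = torch.log_softmax(batch["logits"], dim=-1)
    chosen = log_prob.gather(-1, batch["action"]).squeeze(-1)
    return -(batch["psi"] * chosen).mean()

batch = {
    "reward": torch.randn(5, 1),
    "obs": torch.randn(5, 3),
    "action": torch.randint(0, 2, (5, 1)),
}
# Compute every network output before calling the loss.
batch["psi"] = functional_call(meta_module, meta_params, (batch["reward"],)).squeeze(-1)
batch["logits"] = functional_call(policy, policy_params, (batch["obs"],))
loss = loss_fn(batch)

# Gradients for the two parameter groups can be taken independently.
meta_grads = torch.autograd.grad(loss, list(meta_params.values()), retain_graph=True)
policy_grads = torch.autograd.grad(loss, list(policy_params.values()))
```

No index bookkeeping is needed here, because the meta-parameters and the policy parameters never share a container.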

vmoens commented 2 years ago

Thanks for the answer

We need to maintain indices for this and pass the parameters carefully.

If I understand correctly, what bothers you with the current API is that with functorch as it is now, you are working with a list of parameters/buffers, which makes it hard to assign them to a particular module. Is that right?

What if we were storing the params in a dictionary instead (or better: a TensorDict :D )? @zou3519 mentioned there that some users were asking for this. I'm happy to sketch a solution like this if that would unblock you (and/or implement the solution you were suggesting, even temporarily while waiting for a functorch fix if we opt for it).
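A rough idea of what the dictionary-based variant could look like, assuming PyTorch 2.x where `torch.func.functional_call` accepts a name-to-tensor dict. `LossWithMeta` and its sub-module names are hypothetical and only illustrate how names replace positional indices; this is not the torchrl or functorch API discussed above.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class LossWithMeta(nn.Module):
    """Hypothetical container holding both a policy and a shared meta-network."""

    def __init__(self):
        super().__init__()
        self.policy = nn.Linear(3, 2)
        self.meta = nn.Linear(1, 1)

    def forward(self, obs, reward):
        weight = self.meta(reward)  # learned per-sample weighting
        return (weight * self.policy(obs).sum(-1, keepdim=True)).mean()

torch.manual_seed(0)
module = LossWithMeta()
params = dict(module.named_parameters())  # names instead of positions

# Splitting by prefix replaces manual index bookkeeping.
meta_params = {k: v for k, v in params.items() if k.startswith("meta.")}
policy_params = {k: v for k, v in params.items() if k.startswith("policy.")}

obs, reward = torch.randn(4, 3), torch.randn(4, 1)
loss = functional_call(module, {**policy_params, **meta_params}, (obs, reward))
```

A TensorDict could play the same role as the two plain dicts, with nested keys in place of the dotted prefixes.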

XuehaiPan commented 2 years ago

If I understand correctly, what bothers you with the current API is that with functorch as it is now, you are working with a list of parameters/buffers, which makes it hard to assign them to a particular module. Is that right?

Yes, that's right.

and/or implement the solution you were suggesting

Thanks for this. The new implementation would resolve both points 1 and 2 in https://github.com/facebookresearch/rl/issues/338#issue-1328464480.