pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

[Use Case] Natural-gradient Reinforcement Learning #733

Open DarkbyteAT opened 2 years ago

DarkbyteAT commented 2 years ago

Description

For my thesis project, I'm applying a novel Polyak-averaging approach to various reinforcement-learning algorithms; the approach uses natural-gradient descent to estimate the effective momentum for policy-network parameters.
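For readers unfamiliar with Polyak averaging, here is a minimal sketch of the standard exponential (soft-update) form commonly used in RL, θ̄ ← τ·θ̄ + (1 − τ)·θ, applied to a policy network's parameters. This is the textbook technique, not the thesis's novel variant; the helper name `polyak_update` and the value of `tau` are illustrative assumptions.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def polyak_update(avg_net: nn.Module, net: nn.Module, tau: float = 0.99) -> None:
    """Exponential Polyak averaging, in place: avg <- tau * avg + (1 - tau) * current."""
    for p_avg, p in zip(avg_net.parameters(), net.parameters()):
        p_avg.mul_(tau).add_(p, alpha=1.0 - tau)


torch.manual_seed(0)
policy = nn.Linear(4, 2)
averaged_policy = copy.deepcopy(policy)  # the running average starts at the current weights
w0 = policy.weight.detach().clone()

# Stand-in for one optimiser step on `policy`: shift every weight by +1.
with torch.no_grad():
    policy.weight.add_(1.0)

polyak_update(averaged_policy, policy, tau=0.9)
# averaged weight = 0.9 * w0 + 0.1 * (w0 + 1) = w0 + 0.1
```

Keeping the average in a separate module copy leaves the optimiser's own state untouched; the averaged network is only read from (e.g. for evaluation or as a target).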

Current policy-gradient methods don't fare well with momentum, because training occurs over non-stationary distributions: the accumulated momentum is not invariant to the order in which examples are shown, which makes momentum-based methods such as Adam unworkable. AdamW shows more promise but still seems to fail, and Root Mean Square Propagation (RMSProp) is currently the baseline in research. However, we know that natural-gradient optimisation (sometimes combined with these methods, as in the example at the bottom) can drastically improve performance and sample efficiency, since it can guarantee monotonic improvement of network-parameterised policies. Its main issue is computational cost.
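To make "natural-gradient optimisation" concrete, a minimal sketch of a single damped natural-gradient update, θ ← θ − η (F + λI)⁻¹ ∇L, given a Fisher matrix `F` and gradient `g`. The helper name `natural_gradient_step`, the damping term λ, and the toy diagonal Fisher are assumptions for illustration; real methods differ mainly in how they obtain or approximate `F`.

```python
import torch


def natural_gradient_step(theta: torch.Tensor, grad: torch.Tensor, fisher: torch.Tensor,
                          lr: float = 0.1, damping: float = 1e-4) -> torch.Tensor:
    """Return theta - lr * (F + damping * I)^{-1} grad.

    Damping (Tikhonov regularisation) keeps the solve well-posed when the
    Fisher matrix is singular or badly conditioned.
    """
    eye = torch.eye(fisher.shape[0], dtype=fisher.dtype)
    direction = torch.linalg.solve(fisher + damping * eye, grad)
    return theta - lr * direction


theta = torch.zeros(2, dtype=torch.float64)
grad = torch.tensor([2.0, 4.0], dtype=torch.float64)
fisher = torch.diag(torch.tensor([2.0, 4.0], dtype=torch.float64))

# Plain gradient descent would step twice as far along the high-curvature second axis;
# preconditioning by F^{-1} rescales both directions to the same step length here.
new_theta = natural_gradient_step(theta, grad, fisher, lr=0.1, damping=0.0)
```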

Natural-gradient optimisation methods in reinforcement learning traditionally rely on one of: Hessian-free approaches such as TRPO; Hessian estimates formed by multiplying Jacobian matrices by their transposes; expensive direct computation of Hessians through traditional routes; or, more recently, Kronecker-factored approximations such as KFAC or E-KFAC (and, very recently, M-FAC).
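A minimal sketch of the Hessian-free route: TRPO-style methods solve F x = g with conjugate gradient, using only Fisher-vector products, so F is never materialised. Here the products come from the Jacobian-times-transpose form F ≈ JᵀJ / N. The random `J` (standing in for stacked per-sample score vectors ∇θ log π(a|s)), the damping value, and the hand-written `conjugate_gradient` helper are all assumptions for illustration.

```python
from typing import Callable

import torch


def conjugate_gradient(matvec: Callable[[torch.Tensor], torch.Tensor], b: torch.Tensor,
                       iters: int = 50, tol: float = 1e-10) -> torch.Tensor:
    """Solve A x = b for symmetric positive-definite A, given only v -> A v."""
    x = torch.zeros_like(b)
    r = b.clone()          # residual b - A x (x starts at zero)
    p = r.clone()          # current search direction
    rs_old = r @ r
    b_norm = torch.linalg.norm(b)
    for _ in range(iters):
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if torch.sqrt(rs_new) <= tol * b_norm:  # relative residual small enough
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


torch.manual_seed(0)
n_samples, n_params, damping = 64, 10, 1e-2
J = torch.randn(n_samples, n_params, dtype=torch.float64)  # stand-in per-sample score vectors (rows)
g = torch.randn(n_params, dtype=torch.float64)              # stand-in policy gradient


def fisher_vector_product(v: torch.Tensor) -> torch.Tensor:
    # (J^T J / N + damping * I) v, evaluated as J^T (J v): no n_params x n_params matrix is built.
    return J.T @ (J @ v) / n_samples + damping * v


x_cg = conjugate_gradient(fisher_vector_product, g)

# Reference: build F explicitly and solve directly (only feasible for tiny models).
F = J.T @ J / n_samples + damping * torch.eye(n_params, dtype=torch.float64)
x_direct = torch.linalg.solve(F, g)
```

The cost per CG iteration is two matrix-vector products with `J`, which is why this route scales to parameter counts where forming and inverting F directly would not.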

Existing Solutions

This has been implemented in PyTorch through the KFAC-PyTorch library, as well as the nngeometry library and its associated paper; however, this approach has a few key drawbacks that rule it out as a "plug-and-play" solution of the kind functorch provides:

Drawbacks to Existing Solutions

Proposed Additions

The functorch library, meanwhile, appears to provide a direct, simple way to do the expensive part of the computation (i.e. computing the Hessian matrices), and its outer-loop optimisations seem to speed up the process vastly. The two things missing from this library are:
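As a sketch of that "expensive part", assuming PyTorch 2.x, where functorch's transforms were merged into `torch.func`: writing the loss as a function of the parameter dict (via `torch.func.functional_call`) lets `torch.func.hessian` return every parameter-by-parameter Hessian block. The tiny network, random data, and squared-error loss are placeholders chosen only to keep the example small.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, hessian

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))
x = torch.randn(8, 3)
y = torch.randn(8, 2)

# The parameters become an explicit input that the transforms can differentiate.
params = {name: p.detach() for name, p in net.named_parameters()}


def loss_fn(params: dict[str, torch.Tensor]) -> torch.Tensor:
    pred = functional_call(net, params, (x,))
    return ((pred - y) ** 2).mean()


# Nested dict: H[name_i][name_j] has shape param_i.shape + param_j.shape.
H = hessian(loss_fn)(params)
print(H["2.weight"]["2.weight"].shape)  # torch.Size([2, 4, 2, 4])
```

The block structure mirrors `named_parameters()`, so per-layer blocks (the ones KFAC-style methods approximate) can be read off directly instead of slicing one giant flattened matrix.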

I believe this could spark a great deal of interest in this library and, more importantly, trigger a resurgence of interest in higher-order optimisation methods, which have time and again proven faster at converging than first-order methods. What seems to hold them back is hesitancy in research due to the points mentioned above; I nearly gave up on months of project work because of these issues myself, before finding this paper/library, which made the process much easier. The headaches of the workarounds remain, but a standardised function for this would take this library from a 10/10 to an 11.

Simplified Example Usage

import gym
import numpy as np
import torch as T

from functorch import jacrev, vmap
from torch import distributions as distr, nn as nn, optim as opt
...

class ExampleNet(nn.Module):
    def __init__(self, submodule1: nn.Module, ...):
        super().__init__()
        self.submodule1 = submodule1
        ...

env = gym.make("CartPole-v1")
net = ExampleNet(nn.Linear(input_features, hidden_dim1), nn.ReLU(), nn.Linear(hidden_dim1, hidden_dim2), ...)
optim = opt.RMSprop(net.parameters())
...

# run episode stuff and collect/process data here, loops removed for simplicity/readability
...

visited_states = np.asarray(...)
recorded_actions = np.asarray(...)
collected_rewards = np.asarray(...)
...

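# T.from_numpy accepts only the array; set dtype/device with .to(...) and gradients with .requires_grad_()
observations = T.from_numpy(visited_states).to(...).requires_grad_()
actions = T.from_numpy(recorded_actions).to(...)
rewards = T.from_numpy(collected_rewards).to(...).requires_grad_()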
...

# require functions to use parameters with a given form, maybe order / names / starting argnums?
# there is probably a better way, this is just an example
def log_probs(net_outs: tuple[T.Tensor, ...], actions: T.Tensor, ...) -> T.Tensor:
    return distr.Categorical(net_outs[0]).log_prob(actions)
...

# use method to find change in log probabilities w.r.t network parameters theta
# could also create a new method suffixed by '_net' for simplicity?
# 
# example params to add:
#     * wrt_net: bool (default=False) - if True, returned function accepts parameters:
#         + net: nn.Module - network to find gradient w.r.t
#         + net_ins: tuple | list of Tensor - inputs to network as net(*net_ins) or net.forward(*net_ins) 
#                                                            given to function's 'net_outs' param.
#         + recursive: bool (default=False) - should be w.r.t net.parameters(recursive)
#         + **kwargs: Tensor - could be used to specify function parameters by argument name

jac_fn = vmap(jacrev(log_probs, ..., wrt_net=True))
log_probs, jac_theta = jac_fn(net, net_ins=(observations,), recursive=True, actions=actions)

# process jacobian into appropriate NxN matrix for N samples
...

loss = T.sum(jac_theta * -log_probs * rewards)
optim.zero_grad()
loss.backward()
optim.step()
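The `wrt_net` argument proposed above does not exist in functorch. As a sketch of what is possible today (assuming PyTorch 2.x and `torch.func`), per-sample gradients of log π(a|s) with respect to the network parameters can be obtained without copying the model by combining `functional_call`, `grad` and `vmap`. The network shape, random observations/actions, and the helper name `log_prob` are hypothetical stand-ins for the episode data in the example.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
obs_dim, n_actions, batch = 4, 2, 16
net = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, n_actions))

observations = torch.randn(batch, obs_dim)
actions = torch.randint(0, n_actions, (batch,))

# Parameters become explicit function inputs; no copy of the module is made.
params = {name: p.detach() for name, p in net.named_parameters()}


def log_prob(params: dict[str, torch.Tensor], obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    """log pi(action | obs) for ONE unbatched sample under a categorical policy."""
    logits = functional_call(net, params, (obs,))            # (n_actions,)
    log_pi = logits.log_softmax(dim=-1)
    return log_pi.gather(0, action.unsqueeze(0)).squeeze(0)  # scalar


# grad differentiates w.r.t. params (argnums=0); vmap maps over the sample
# dimension of obs and action while sharing params (in_dims=None).
per_sample_grads = vmap(grad(log_prob), in_dims=(None, 0, 0))(params, observations, actions)

for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))  # e.g. 0.weight (16, 32, 4)
```

Each entry has a leading sample dimension, so flattening and concatenating them gives the N × P score matrix from which a Fisher estimate (scoresᵀ · scores / N) can be formed.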
zou3519 commented 2 years ago

Thanks for the detailed issue, @DarkbyteAT!

cc @vmoens for discussion around RL.

Implementations of higher-order approximations for common machine-learning methods.

To clarify, which higher-order approximations are you interested in?

Standardised way to find both per-sample and batched higher-order functions w.r.t network parameters. Currently this requires hacky, performance-impacting workarounds involving creating copies of models to compute properly.

@DarkbyteAT, is your claim the following:

We are indeed interested in a more natural API that works directly on PyTorch NN modules but it hasn't been clear to us what that would look like (cc @Chillee). Your opinions here are very helpful though!

One question I have here is: what exactly does

jac_fn = vmap(jacrev(log_probs, ..., wrt_net=True))
log_probs, jac_theta = jac_fn(net, net_ins=(observations,), recursive=True, actions=actions)

return for jac_theta? We're computing the jacobian with respect to the parameters in the network. Does this return:

ain-soph commented 2 years ago

@DarkbyteAT As far as I know, KFAC and EKFAC are both methods to approximate the inverse of the Fisher Information Matrix.

The natural gradient is the gradient corrected for the Riemannian geometry of the space of distributions, where distance is measured by the Kullback-Leibler divergence. Its metric tensor is the Fisher Information Matrix (FIM), so the natural gradient is the ordinary gradient left-multiplied by the inverse of the FIM (see https://en.wikipedia.org/wiki/Gradient#Riemannian_manifolds).
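For reference, the standard definitions behind this paragraph, written for a model p_θ(y | x) with loss L and step size η. The second line is the second-order expansion that links the FIM to the KL divergence; the third is the resulting update.

```latex
F(\theta) = \mathbb{E}_{x}\,\mathbb{E}_{y \sim p_\theta(y \mid x)}\!\left[\nabla_\theta \log p_\theta(y \mid x)\,\nabla_\theta \log p_\theta(y \mid x)^{\top}\right]

\mathrm{KL}\!\left(p_\theta \,\|\, p_{\theta+\delta}\right) \approx \tfrac{1}{2}\,\delta^{\top} F(\theta)\,\delta

\tilde{\nabla}_\theta L = F(\theta)^{-1}\,\nabla_\theta L, \qquad \theta \leftarrow \theta - \eta\,\tilde{\nabla}_\theta L
```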
However, calculating the whole FIM and its inverse is costly in both memory and time. The KFAC and EKFAC papers propose a memory-efficient alternative that never computes the FIM explicitly, using the following approximation (from the paper):

[image: the FIM approximation from the KFAC/EKFAC paper]

I have also implemented KFAC and EKFAC, based on the authors' code (their implementations are outdated):
https://ain-soph.github.io/trojanzoo/trojanzoo/utils/fim.html
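A minimal sketch of the Kronecker-factored idea for a single fully connected layer, under KFAC's assumption that layer inputs and output gradients are independent: the layer's Fisher block is approximated by the Kronecker product of two small matrices, so applying its inverse to the weight gradient needs only two small inverses. Random tensors stand in for what forward/backward hooks would record, and this is the textbook KFAC form, not necessarily the exact approximation shown in the paper's figure.

```python
import torch

torch.manual_seed(0)
n, d_in, d_out, damping = 256, 4, 3, 1e-3
dtype = torch.float64

# Stand-ins for hooked values on one nn.Linear:
# a: layer inputs (N, d_in); g: gradients w.r.t. the layer's outputs (N, d_out).
a = torch.randn(n, d_in, dtype=dtype)
g = torch.randn(n, d_out, dtype=dtype)

# Kronecker factors A = E[a a^T] and G = E[g g^T], each with Tikhonov damping.
A = a.T @ a / n + damping * torch.eye(d_in, dtype=dtype)
G = g.T @ g / n + damping * torch.eye(d_out, dtype=dtype)

# Weight gradient for W of shape (d_out, d_in): E[g a^T].
grad_W = g.T @ a / n

# KFAC: F ~ kron(G, A) for the row-major flattening of W, hence
# F^{-1} vec(grad_W) = vec(G^{-1} grad_W A^{-1}); only small matrices are inverted.
nat_grad = torch.linalg.inv(G) @ grad_W @ torch.linalg.inv(A)

# Check against the explicit (d_out*d_in) x (d_out*d_in) Kronecker product.
F_kron = torch.kron(G, A)
explicit = torch.linalg.solve(F_kron, grad_W.reshape(-1)).reshape(d_out, d_in)
```

For a layer with 1000 inputs and 1000 outputs, the explicit block would be 10⁶ × 10⁶, while the factors are two 1000 × 1000 matrices; that gap is the whole motivation for the factorisation.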

But if you really do want to calculate the full FIM for individual layers, I think the following script should work. It doesn't use functorch, though; you could avoid using torch.nn.utils._stateless.functional_call by using functorch instead.

Calculating the FIM for all layers jointly is impractical, since its size grows quadratically with the total parameter count, so I calculate the FIM for each layer separately.

import torch
import torch.nn as nn
from torch.nn.utils import _stateless

def fim(module: nn.Module, _input: torch.Tensor,
            parameters: dict[str, nn.Parameter] = None
            ) -> list[torch.Tensor]:
    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())

    with torch.no_grad():
        _output: torch.Tensor = module(_input)  # (N, C)
        prob = _output.softmax(dim=1).unsqueeze(-1).unsqueeze(-1)  # (N, C, 1, 1)

    def func(*params: torch.Tensor):
        _output: torch.Tensor = _stateless.functional_call(
            module, {n: p for n, p in zip(keys, params)}, _input)
        return _output.log_softmax(dim=1)  # (N, C)
    jac: tuple[torch.Tensor] = torch.autograd.functional.jacobian(
        func, values)
    jac = (j.flatten(2) for j in jac)   # (N, C, D)

    fim_list: list[torch.Tensor] = []
    for j in jac:  # TODO: parallel
        fim = prob * j.unsqueeze(-1) * \
            j.unsqueeze(-2)  # (N, C, D, D)
        fim_list.append(fim.sum(1).mean(0))   # (D, D)
    return fim_list
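Following the suggestion to replace the private `_stateless` module, a sketch of the same per-parameter-tensor FIM using `torch.func` (functorch's successor in PyTorch 2.x). The function name `fim_func` and the returned dict keyed by parameter name are hypothetical choices; the `einsum` computes the same Σ_c p(c|x) ∇log p ∇log pᵀ, averaged over samples, as the `prob * j * j` sum/mean in the script.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev


def fim_func(module: nn.Module, _input: torch.Tensor) -> dict[str, torch.Tensor]:
    params = {name: p.detach() for name, p in module.named_parameters()}

    with torch.no_grad():
        prob = module(_input).softmax(dim=1)  # (N, C), treated as constant weights

    def log_probs(params: dict[str, torch.Tensor]) -> torch.Tensor:
        return functional_call(module, params, (_input,)).log_softmax(dim=1)  # (N, C)

    jac = jacrev(log_probs)(params)  # name -> (N, C, *param.shape)

    fims = {}
    for name, j in jac.items():
        j = j.flatten(2)  # (N, C, D)
        # F = mean_n sum_c p(c|x_n) * grad log p(c|x_n) grad log p(c|x_n)^T  -> (D, D)
        fims[name] = torch.einsum("nc,ncd,nce->de", prob, j, j) / j.shape[0]
    return fims


torch.manual_seed(0)
model = nn.Linear(5, 3)
x = torch.randn(10, 5)
fims = fim_func(model, x)
print({name: tuple(f.shape) for name, f in fims.items()})  # {'weight': (15, 15), 'bias': (3, 3)}
```

For a softmax-linear model the bias block has the closed form mean over samples of diag(p) − p pᵀ, which makes a convenient sanity check for either implementation.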