Open DarkbyteAT opened 2 years ago
Thanks for the detailed issue, @DarkbyteAT!
cc @vmoens for discussion around RL.
Implementations of higher-order approximations for common machine-learning methods.
To clarify, which higher-order approximations are you interested in?
A standardised way to find both per-sample and batched higher-order functions w.r.t. network parameters. Currently this requires hacky, performance-impacting workarounds involving creating copies of models to compute these properly.
@DarkbyteAT, is your claim the following:
We are indeed interested in a more natural API that works directly on PyTorch NN modules but it hasn't been clear to us what that would look like (cc @Chillee). Your opinions here are very helpful though!
One question I have here is: what exactly does
jac_fn = vmap(jacrev(log_probs, ..., wrt_net=True))
log_probs, jac_theta = jac_fn(net, net_ins=(observations,), recursive=True, actions=actions)
return for jac_theta? We're computing the jacobian with respect to the parameters in the network. Does this return:
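For context, here is a rough sketch of how a per-sample Jacobian of action log-probabilities w.r.t. network parameters can already be computed with functorch's existing make_functional, vmap and jacrev (my own illustration; the policy network, observation sizes and shapes below are made-up assumptions, not anything from the issue):

import torch
import torch.nn as nn
from functorch import make_functional, vmap, jacrev

# made-up policy network and batch of observations
net = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 2))
observations = torch.randn(32, 4)

# functional form of the module: parameters become an explicit argument
fnet, params = make_functional(net)

def log_probs(params, obs):
    # obs is a single observation here; the vmap below adds the batch dimension back
    return torch.log_softmax(fnet(params, obs), dim=-1)  # (num_actions,)

# per-sample Jacobian w.r.t. every parameter tensor: a tuple with one entry per
# parameter, each of shape (batch, num_actions, *param.shape)
jac_theta = vmap(jacrev(log_probs), in_dims=(None, 0))(params, observations)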
@DarkbyteAT As far as I know, KFAC and EKFAC are both methods for approximating the inverse of the Fisher Information Matrix.
The natural gradient is the gradient corrected for the Riemannian geometry of the parameter space, where the distance between distributions is measured (locally) by KL divergence. The corresponding metric tensor is the Fisher Information Matrix (FIM), so the corrected gradient is the ordinary gradient left-multiplied by the inverse of the FIM (see https://en.wikipedia.org/wiki/Gradient#Riemannian_manifolds).
However, it's memory- and time-costly to calculate the whole FIM and its inverse. The KFAC and EKFAC papers propose memory-efficient schemes that avoid computing the FIM explicitly by using a Kronecker-factored approximation (see the papers for the exact form).
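To make the memory argument concrete, here is a tiny numerical check of the Kronecker identity such approximations rely on (my own illustrative sketch with arbitrary sizes, not code or notation from either paper): if F is approximated by kron(A, G) with symmetric positive-definite factors, then applying F^(-1) to a vectorised gradient equals vec(G^(-1) V A^(-1)), so the inverse never has to be formed at full size.

import torch

torch.manual_seed(0)
da, dg = 4, 3
A = torch.randn(da, da, dtype=torch.float64)
A = A @ A.T + da * torch.eye(da, dtype=torch.float64)  # symmetric positive-definite factor
G = torch.randn(dg, dg, dtype=torch.float64)
G = G @ G.T + dg * torch.eye(dg, dtype=torch.float64)  # symmetric positive-definite factor
V = torch.randn(dg, da, dtype=torch.float64)           # e.g. a layer's gradient as a matrix

vec = lambda X: X.T.reshape(-1)                        # column-major vectorisation
full = torch.linalg.solve(torch.kron(A, G), vec(V))    # applies the full Kronecker inverse
factored = vec(torch.linalg.inv(G) @ V @ torch.linalg.inv(A))  # factor-wise inverses only
print(torch.allclose(full, factored))                  # True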
I also implemented KFAC and EKFAC based on their code (the original implementations are quite old):
https://ain-soph.github.io/trojanzoo/trojanzoo/utils/fim.html
But if you really do want to calculate the full-scale FIM for different layers, I think the following script should work. It doesn't use functorch, though; you could avoid torch.nn.utils._stateless.functional_call by using functorch instead.
It's infeasible to calculate the FIM over all layers jointly (the full matrix would be far too large), so I calculate a FIM block for each layer separately.
import torch
import torch.nn as nn
from torch.nn.utils import _stateless


def fim(module: nn.Module, _input: torch.Tensor,
        parameters: dict[str, nn.Parameter] = None
        ) -> list[torch.Tensor]:
    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())
    with torch.no_grad():
        _output: torch.Tensor = module(_input)  # (N, C)
        # predictive distribution, used to weight the outer products below
        prob = _output.softmax(dim=1).unsqueeze(-1).unsqueeze(-1)  # (N, C, 1, 1)

    def func(*params: torch.Tensor):
        # stateless forward pass so autograd can differentiate w.r.t. `params`
        _output: torch.Tensor = _stateless.functional_call(
            module, {n: p for n, p in zip(keys, params)}, _input)
        return _output.log_softmax(dim=1)  # (N, C)

    # one Jacobian tensor per parameter, each of shape (N, C, *param.shape)
    jac: tuple[torch.Tensor, ...] = torch.autograd.functional.jacobian(
        func, values)
    jac = (j.flatten(2) for j in jac)  # (N, C, D)
    fim_list: list[torch.Tensor] = []
    for j in jac:  # TODO: parallel
        # E_{y ~ p(y|x)}[(d log p / d theta)(d log p / d theta)^T], averaged over the batch
        fim = prob * j.unsqueeze(-1) * \
            j.unsqueeze(-2)  # (N, C, D, D)
        fim_list.append(fim.sum(1).mean(0))  # (D, D)
    return fim_list
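A minimal usage sketch for the script above (the model and input sizes here are made up):

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))
x = torch.randn(8, 10)                  # batch of 8 ten-dimensional inputs
blocks = fim(model, x)                  # one (D, D) block per parameter tensor
for (name, p), block in zip(model.named_parameters(), blocks):
    print(name, p.numel(), tuple(block.shape))  # block is (p.numel(), p.numel())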
Description
For my thesis project, I'm applying a novel Polyak-averaging approach to various reinforcement learning algorithms; the approach uses natural-gradient descent in order to estimate the effective momentum for policy-network parameters.
Current policy-gradient methods don't fare well with momentum because training occurs over non-stationary distributions, so the computed momentum is not invariant to the order in which examples are seen, making momentum-based methods such as Adam unworkable. AdamW shows more promise but still seems to fail, and RMSProp (Root Mean Square Propagation) is currently the baseline in the literature. However, we know that natural-gradient optimisation (sometimes combined with these methods, as shown in the example at the bottom) can drastically improve performance and sample efficiency, since it can guarantee monotonic improvement of network-parameterised policies; its main drawback is computational cost.
Natural-gradient optimisation methods in reinforcement learning traditionally rely on Hessian-free approaches such as TRPO, Hessian estimates computed from Jacobian matrices multiplied by their transposes, expensive direct computation of Hessians, or, more recently, Kronecker-factorised approximations such as KFAC and E-KFAC (and very recently M-FAC).
Existing Solutions
This has been implemented in PyTorch through the KFAC-PyTorch library, as well as the nngeometry library and its associated paper; however, these have a few key drawbacks that discount their usage as a "plug-and-play" solution such as the one provided by functorch:

Drawbacks to Existing Solutions
- The KFAC-PyTorch library only provides optimisers that implement KFAC and E-KFAC for natural-gradient descent and ascent. One of the approaches used in papers for naturalising gradients is to use the inverse (or, more commonly, the Moore-Penrose pseudoinverse) of the Fisher information matrix (calculated as the second derivative of the KL-divergence between the new and old policies w.r.t. the parameters theta) as a preconditioning factor on first-order gradient-descent methods (a rough sketch of this follows the list); this library doesn't provide any way to do that.
- The nngeometry library is vastly over-complicated for most online-learning use-cases; it is tailored more towards supervised- or offline-learning settings, requiring DataLoaders and worker processes to speed up its Fisher-matrix approximations.
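As a rough sketch of the preconditioning use-case mentioned in the first drawback above (my own illustration, not an API from any of these libraries; the damping term and the fisher_block argument are assumptions):

import torch

def natural_gradient(grad: torch.Tensor, fisher_block: torch.Tensor,
                     damping: float = 1e-3) -> torch.Tensor:
    # precondition a first-order gradient with the (damped) pseudoinverse of the
    # corresponding Fisher information block, e.g. one produced by the fim() script above
    d = grad.numel()
    precond = torch.linalg.pinv(fisher_block + damping * torch.eye(d))
    return (precond @ grad.reshape(-1)).reshape(grad.shape)

# usage: naturalise an SGD step for a single parameter tensor p
# p.data -= lr * natural_gradient(p.grad, fisher_block)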
Proposed Additions
The functorch library, meanwhile, appears to provide a direct, simple way to do the expensive part of the computation (i.e. computing the Hessian matrices) and seems to be faster, with its outer-loop optimisations speeding the process up vastly. The two things missing from this library are:
- Implementations of higher-order approximations for common machine-learning methods.
- A standardised way to find both per-sample and batched higher-order functions w.r.t. network parameters, without hacky, performance-impacting workarounds involving copies of models.

I believe that a great deal of interest could be sparked in this library, and, more importantly, a resurgence of interest in higher-order optimisation methods, which have time and time again proven to converge faster than first-order methods; what seems to be holding them back is hesitancy in research due to the points mentioned above. I nearly gave up on months of project work over these issues myself before finding this paper/library, which made the process so much easier. The headaches of the workarounds still stood, but having a standardised function for this would take the library from a 10/10 to an 11.
Simplified Example Usage
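(The two lines below simply reproduce the snippet quoted earlier in the thread; wrt_net, net_ins and recursive are the proposed, hypothetical arguments rather than anything that exists in functorch today.)

# proposed (hypothetical) API: jacrev/vmap applied directly to an nn.Module's parameters
jac_fn = vmap(jacrev(log_probs, ..., wrt_net=True))
log_probs, jac_theta = jac_fn(net, net_ins=(observations,), recursive=True, actions=actions)
# jac_theta would then hold the per-sample Jacobian of the action log-probabilities
# w.r.t. the network parameters, ready to be used as a natural-gradient preconditioner.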