pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

function(Jacobian)-dot-vector and vector-Jacobian-vector function #1056

Open veya2ztn opened 1 year ago


Hi,

I'd like to use functorch to implement the following loss:

Question description

Assume a model with weights $W$ maps an input $x$ of dimension $I$ to an output $y$ of dimension $O$.

There exists an $O\times I$ Jacobian matrix, denoted $J_\alpha^\gamma=\frac{\partial y^{\gamma}}{\partial x_\alpha}$.

I want to calculate two terms:

$$ L1=\sum_\gamma\Big(\sum_\alpha J_\alpha^{\gamma}-1\Big)^2 $$

$$ L2=\sum_\gamma\Big[\sum_\alpha (J_\alpha^{\gamma})^2-1\Big]^2 $$

as well as their gradients with respect to $W$, $\frac{\partial L1}{\partial W}$ and $\frac{\partial L2}{\partial W}$, for the gradient descent update.
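
In matrix notation, the two inner sums can be restated as follows (a restatement of the definitions above, with $\mathbf{1}$ the all-ones vector of length $I$ and no summation over the repeated $\gamma$):

$$ \sum_\alpha J_\alpha^{\gamma} = (J\mathbf{1})^{\gamma}, \qquad \sum_\alpha (J_\alpha^{\gamma})^2 = (JJ^\top)^{\gamma\gamma} $$

So the inner sum of L1 is a single Jacobian-vector product, while the inner sum of L2 is the diagonal of $JJ^\top$, which in general cannot be read off from one product of $J$ with a fixed vector.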

This is easy to implement with the help of functorch. I post a toy example below:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import torch
import torch.nn.functional as F
import functorch
from functorch import jacrev,jacfwd
from functorch import make_functional, vmap, grad
B=200
I=100
O=300
class MyModel(torch.nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.backbone = torch.nn.Linear(in_chan, out_chan,bias=False)
    def forward(self,x):
        return self.backbone(x)**2
model= MyModel(I, O).cuda()
x    = torch.randn(B, I).cuda()
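# All-ones tangent with the same shape as x: the JVP with it gives the
# row sums sum_alpha J_alpha^gamma without building the full Jacobian.
cotangents = torch.ones(B,I).cuda()
func_model, params = make_functional(model)

### ---> to calculate the dL1/dw term
def Normlization_Term_1(params,x):
    # jvp returns (output, J @ tangent); [1] selects the JVP, shape (B, O)
    return ((functorch.jvp(lambda x:func_model(params,x), (x,), (cotangents,)
        )[1]-1)**2).mean()
Derivation_Term_1 = jacrev(Normlization_Term_1, argnums=0)(params, x)

### ---> to calculate the dL2/dw term
# The per-sample Jacobian has shape (B, O, I); the result has shape (B, O)
# because there is no final reduction here.
Normlization_Term_2= lambda params,x:(
    (vmap(jacrev(func_model, argnums=1), (None, 0))(params, x)**2).sum(-1)-1
    )**2
Derivation_Term_2 = jacrev(Normlization_Term_2, argnums=0)(params, x)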
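
As a sanity check on the first term, the sketch below compares the JVP against an all-ones tangent with the row sums of the explicitly built Jacobian. It is a small CPU-only illustration, not the code from the post: it assumes PyTorch 2.0 or later, where the functorch transforms are available under `torch.func`, and the helper `f` with an explicit weight tensor `W` is a hypothetical stand-in for `func_model`.

```python
import torch
from torch.func import jacrev, jvp, vmap

torch.manual_seed(0)
B, I, O = 4, 5, 3  # tiny sizes so the full Jacobian is cheap to build

def f(W, x):
    # Hypothetical stand-in for func_model: bias-free linear layer, then a square.
    return (x @ W.T) ** 2

W = torch.randn(O, I, dtype=torch.float64)
x = torch.randn(B, I, dtype=torch.float64)
ones = torch.ones_like(x)

# One JVP with an all-ones tangent gives sum_alpha J_alpha^gamma per sample.
_, row_sums_jvp = jvp(lambda inp: f(W, inp), (x,), (ones,))  # shape (B, O)

# Reference: build each sample's (O, I) Jacobian and sum over the input index.
full_jac = vmap(jacrev(f, argnums=1), in_dims=(None, 0))(W, x)  # shape (B, O, I)
row_sums_full = full_jac.sum(-1)

print(torch.allclose(row_sums_jvp, row_sums_full))
```

The batched JVP matches the per-sample row sums only because this toy model treats the rows of `x` independently; a model that mixes samples (for example through batch normalization) would need a per-sample `vmap` instead.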

Problem

The idea is to calculate:

I suppose the out-of-memory (OOM) error arises because the second case needs access to the full Jacobian matrix $J_\alpha^{\gamma}$, which is too large to store during computation.

The OOM issue is also reported in https://github.com/pytorch/functorch/issues/636#issue-1185946292 and is (possibly) solved by the recent update that adds a chunks option, discussed in https://github.com/pytorch/functorch/issues/680#issue-1197453691
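
One way to avoid holding the whole Jacobian at once is to accumulate $\sum_\alpha (J_\alpha^{\gamma})^2$ over chunks of input directions. The sketch below is an illustration under assumptions, not the chunks feature from the linked issue: it assumes PyTorch 2.0 or later (`torch.func`); the helpers `f`, `row_sq_norms_chunked`, and `loss2` and the parameter `chunk_size` are hypothetical names; and the final `.mean()` is a choice made here to obtain a scalar loss. Reverse-mode differentiation still stores intermediates for every chunk, so this limits the size of each Jacobian slice but does not by itself guarantee that the OOM goes away.

```python
import torch
from torch.func import grad, jvp, vmap

def f(W, X):
    # Hypothetical stand-in for func_model: bias-free linear layer, then a square.
    return (X @ W.T) ** 2  # X: (B, I), W: (O, I) -> (B, O)

def row_sq_norms_chunked(W, X, chunk_size):
    """Return sum_alpha (J_alpha^gamma)^2 per sample, shape (B, O)."""
    I = X.shape[1]
    eye = torch.eye(I, dtype=X.dtype, device=X.device)
    total = torch.zeros(X.shape[0], W.shape[0], dtype=X.dtype, device=X.device)

    def jac_column(e):
        # JVP along the same input direction e for every sample in the batch;
        # valid here because the model treats the rows of X independently.
        tangent = e * torch.ones_like(X)
        return jvp(lambda inp: f(W, inp), (X,), (tangent,))[1]  # (B, O)

    for start in range(0, I, chunk_size):
        basis = eye[start:start + chunk_size]  # (c, I) unit directions
        cols = vmap(jac_column)(basis)         # (c, B, O): c Jacobian columns
        total = total + (cols ** 2).sum(0)     # accumulate squared entries
    return total

def loss2(W, X, chunk_size=2):
    # Mean instead of sum so the result is a scalar that grad can differentiate.
    return ((row_sq_norms_chunked(W, X, chunk_size) - 1) ** 2).mean()

torch.manual_seed(0)
B, I, O = 4, 5, 3
W = torch.randn(O, I, dtype=torch.float64)
X = torch.randn(B, I, dtype=torch.float64)

dL2_dW = grad(loss2)(W, X)  # gradient with respect to W, shape (O, I)
```

A smaller `chunk_size` trades speed for a smaller `(c, B, O)` slice per step; `chunk_size=I` reproduces the full-Jacobian computation in one pass.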

My ideas are


I checked the source code of jvp: it directly uses the dual-tensor mode of PyTorch forward-mode AD (fwAD) and returns the jvp term directly from _unpack_dual, so I am afraid this problem may be beyond the scope of the functorch pipeline.
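
For reference, the dual-tensor mechanism mentioned above can be exercised directly through PyTorch's public forward-mode AD API. This is a minimal sketch, assuming PyTorch 2.0 or later; the toy function `f` is a hypothetical stand-in, and the comparison with `torch.func.jvp` only illustrates that both routes give the same tangent, not how the library is implemented internally.

```python
import torch
import torch.autograd.forward_ad as fwAD
from torch.func import jvp

torch.manual_seed(0)
W = torch.randn(3, 5, dtype=torch.float64)
x = torch.randn(5, dtype=torch.float64)
v = torch.ones_like(x)  # tangent direction

def f(inp):
    # Hypothetical toy model: linear map followed by a square.
    return (W @ inp) ** 2

# Manual forward-mode AD: attach the tangent to the input as a dual tensor,
# run the function once, then split the result into primal and tangent parts.
with fwAD.dual_level():
    dual_out = f(fwAD.make_dual(x, v))
    primal_out, tangent_out = fwAD.unpack_dual(dual_out)

# The same quantity through the function-transform API.
_, jvp_out = jvp(f, (x,), (v,))

print(torch.allclose(tangent_out, jvp_out))
```

In both routes a single forward pass yields $J v$ only, which is why the row sums in L1 are cheap while the squared entries in L2 are not directly available this way.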

Anyway, I look forward to your discussion.