Hi, I'd like to use functorch to realize the following loss.

Question demonstration

Assume the input dimension of the primal is $I$ (with each element marked $x_\alpha$) and the output dimension is $O$ (with elements $y^\gamma$); then there exists a Jacobian matrix of shape $O\times I$, marked $J_\alpha^{\gamma}=\frac{\partial y^{\gamma}}{\partial x_\alpha}$. I want to calculate two terms,

$$ L1=\sum_\gamma\Big(\sum_\alpha J_\alpha^{\gamma}-1\Big)^2 $$

$$ L2=\sum_\gamma\Big[\sum_\alpha (J_\alpha^{\gamma})^2-1\Big]^2 $$

as well as their gradients with respect to $W$, $\frac{\partial L1}{\partial W}$ and $\frac{\partial L2}{\partial W}$, for the gradient descent update.
This is easy to realize with the help of functorch; I post a toy example below:
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import torch
import functorch
from functorch import jacrev, make_functional, vmap

B = 200  # batch size
I = 100  # input dimension
O = 300  # output dimension

class MyModel(torch.nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.backbone = torch.nn.Linear(in_chan, out_chan, bias=False)

    def forward(self, x):
        return self.backbone(x) ** 2

model = MyModel(I, O).cuda()
x = torch.randn(B, I).cuda()
cotangents = torch.ones(B, I).cuda()  # all-one tangent realizes the sum over alpha
func_model, params = make_functional(model)

### ---> to calculate the dL1/dW term
def Normalization_Term_1(params, x):
    # jvp with an all-one tangent gives sum_alpha J_alpha^gamma, shape (B, O)
    return ((functorch.jvp(lambda x: func_model(params, x),
                           (x,), (cotangents,))[1] - 1) ** 2).mean()

Derivation_Term_1 = jacrev(Normalization_Term_1, argnums=0)(params, x)

### ---> to calculate the dL2/dW term
def Normalization_Term_2(params, x):
    # per-sample input Jacobian, shape (B, O, I) -- this is what blows up memory
    J = vmap(jacrev(func_model, argnums=1), (None, 0))(params, x)
    return (((J ** 2).sum(-1) - 1) ** 2).mean()  # reduce to a scalar loss

Derivation_Term_2 = jacrev(Normalization_Term_2, argnums=0)(params, x)
```
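As a quick sanity check of the jvp trick above (a hedged sketch with tiny hypothetical sizes, using the stable `torch.autograd.functional` API rather than the functorch code above): a jvp with an all-one tangent should equal the row-wise sum of the explicit Jacobian.

```python
# Verify that jvp(f, x, ones) == sum_alpha J_alpha^gamma for each output gamma.
# Sizes here are small illustrative choices, not the B/I/O of the real run.
import torch
from torch.autograd.functional import jvp, jacobian

torch.manual_seed(0)
I, O = 4, 3
W = torch.randn(O, I)

def f(x):
    # same toy model as above: bias-free linear layer followed by squaring
    return (x @ W.t()) ** 2

x = torch.randn(I)
_, jvp_out = jvp(f, x, torch.ones(I))  # forward-mode product J @ ones
J = jacobian(f, x)                     # full (O, I) Jacobian, fine at this size
assert torch.allclose(jvp_out, J.sum(-1), atol=1e-5)
```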
Problem

The idea is to calculate $\sum_\alpha J_\alpha^{\gamma}$ first. This term is easy to realize with functorch.jvp or torch.autograd.functional.jvp by setting the cotangents to an all-one tensor torch.ones(B, I). If we do the summation $\sum_\gamma$ inside the wrapped function and then differentiate with respect to the model's parameters $W$, it runs fast and costs little memory.
However, for the next term $\sum_\alpha (J_\alpha^{\gamma})^2$, there is no jvp-style function: I have to create the full Jacobian of the primal followed by a .sum() to obtain the result. In that case we face an OOM problem; my machine is an A100-80G.
I suppose this is because we have to materialize the full Jacobian matrix $J_\alpha^{\gamma}$ in the second case, which is too large to store during computation.

The OOM issue is also reported in https://github.com/pytorch/functorch/issues/636#issue-1185946292 and (possibly) solved by the recent update with the chunks option in https://github.com/pytorch/functorch/issues/680#issue-1197453691.
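A rough back-of-the-envelope illustration of how fast these Jacobians grow (my own arithmetic, assuming float32 and the toy sizes above; it is not a profile of the actual run, and peak usage during backward will be a multiple of these figures):

```python
# Sizes from the toy example; 4 bytes per float32 element.
B, I, O = 200, 100, 300
bytes_per_float = 4

# The per-sample input Jacobian vmap(jacrev(...)) materializes a (B, O, I) tensor.
full_jacobian = B * O * I * bytes_per_float
print(f"per-sample input Jacobian (B, O, I): {full_jacobian / 2**20:.1f} MiB")

# If the (B, O) loss tensor is NOT reduced to a scalar before jacrev w.r.t.
# the (O, I) weight, the result is a (B, O, O, I) tensor instead:
unreduced = B * O * O * I * bytes_per_float
print(f"unreduced (B, O, O, I) Jacobian: {unreduced / 2**30:.2f} GiB")
```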
My ideas are:

Can we build a native function that produces an F(Jacobian)-dot-vector output, $f(J)\cdot \vec{n}\rightarrow \vec{v}$?
If $f:x\rightarrow x$, then it is exactly functorch.jvp, $J\cdot \vec{n}\rightarrow \vec{v}$.
If $f: x\rightarrow x^2$, then it is the second term in my example. This time, since it does not need to access the full Jacobian, it becomes more memory efficient.
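Such a squared-Jacobian-vector product can already be emulated, if inefficiently, by looping a forward-mode jvp over input basis vectors so that only one Jacobian column lives in memory at a time. The sketch below is my own construction (the name `squared_jvp` and the basis-vector loop are not functorch API), traded off as O(I) jvp calls instead of O(O·I) storage:

```python
# (J ∘ J) @ n computed one Jacobian column at a time via forward-mode jvp,
# so the full (O, I) Jacobian is never materialized.
import torch
from torch.autograd.functional import jvp

def squared_jvp(f, x, n):
    """Return sum_alpha (J_alpha^gamma)^2 * n_alpha without building J."""
    out = None
    for alpha in range(x.numel()):
        e = torch.zeros_like(x)
        e.view(-1)[alpha] = 1.0
        _, col = jvp(f, x, e)                 # col == J[:, alpha]
        term = (col ** 2) * n.view(-1)[alpha]
        out = term if out is None else out + term
    return out

torch.manual_seed(0)
I, O = 4, 3
W = torch.randn(O, I)
f = lambda x: (x @ W.t()) ** 2
x = torch.randn(I)

v = squared_jvp(f, x, torch.ones(I))
# check against the explicit Jacobian at this tiny size
J = torch.autograd.functional.jacobian(f, x)
assert torch.allclose(v, (J ** 2).sum(-1), atol=1e-5)
```

For the gradient of the resulting loss with respect to $W$ one would additionally need `create_graph=True` in the jvp calls, and the runtime scales linearly with the input dimension, so this is only a workaround, not the native primitive asked for above.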
Some usages of the Jacobian would only require:

Jacobian-dot-vector, producing a vector: covered by functorch.jvp.
vector-dot-Jacobian, producing a vector: covered by functorch.vjp.
vector-dot-Jacobian-dot-vector, producing a scalar: needs to be realized via jvp or vjp.

When doing gradient calculation on those outputs, the memory needed to store intermediate tensors is around (dimension of the vector) × (number of parameters). Is it possible to realize a native vector-dot-Jacobian-dot-vector that avoids accessing those large intermediates and thus becomes memory efficient?
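The third case above can be sketched with existing primitives: one jvp produces $J\vec{v}$ (an O-sized vector), and a dot product with $\vec{u}$ contracts it to the scalar, so the only intermediate is output-sized. This is my own helper (`vjvp` is not a functorch name), shown with `torch.autograd.functional`:

```python
# u^T @ J @ v as a scalar via a single forward-mode jvp plus a dot product;
# the full Jacobian is never materialized.
import torch
from torch.autograd.functional import jvp

def vjvp(f, x, u, v):
    """Return u^T @ J(x) @ v without building J."""
    _, jv = jvp(f, x, v)      # J @ v, a vector of output size O
    return torch.dot(u, jv)   # contract with u -> scalar

torch.manual_seed(0)
I, O = 4, 3
W = torch.randn(O, I)
f = lambda x: (x @ W.t()) ** 2
x = torch.randn(I)
u, v = torch.randn(O), torch.randn(I)

s = vjvp(f, x, u, v)
J = torch.autograd.functional.jacobian(f, x)
assert torch.allclose(s, u @ J @ v, atol=1e-5)
```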
I checked the source code of jvp: it directly uses the dual mode of PyTorch forward-mode AD and returns the jvp term straight from _unpack_dual, so I am afraid this problem may be beyond the scope of the functorch pipeline.
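For reference, that dual-number mechanism is exposed publicly as torch.autograd.forward_ad; the sketch below (my own tiny example, same toy model shapes as before) shows the make_dual / unpack_dual round trip that jvp wraps:

```python
# A dual tensor carries (primal, tangent); ordinary ops propagate the tangent,
# and unpack_dual recovers J @ n after the forward pass.
import torch
import torch.autograd.forward_ad as fwAD

torch.manual_seed(0)
I, O = 4, 3
W = torch.randn(O, I)
x = torch.randn(I)
n = torch.ones(I)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, n)          # attach tangent n to the primal x
    dual_y = (dual_x @ W.t()) ** 2         # ordinary ops propagate the tangent
    y, jvp_out = fwAD.unpack_dual(dual_y)  # jvp_out == J @ n

J = torch.autograd.functional.jacobian(lambda x: (x @ W.t()) ** 2, x)
assert torch.allclose(jvp_out, J.sum(-1), atol=1e-5)
```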
Anyway, I look forward to your discussion.