pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Is there a way to parallelize or accelerate a loop of column-by-column jvp? #1043

Status: Open. kwmaeng91 opened this issue 1 year ago.

kwmaeng91 commented 1 year ago

Hi, experts. I am currently computing a Jacobian column by column and summing the squares of each column's entries to get the trace of JᵀJ (i.e., the squared Frobenius norm of the Jacobian J). The code looks something like this:

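```python
import torch
from functorch import vmap, jvp

# net: the model (not shown); x: a batch of inputs of shape [B, D]
def jvp_func(x, tgt):
    # Jacobian-vector product of net at x along the tangent tgt
    return jvp(net, (x,), (tgt,))

tr = 0
for j in range(x[0].shape[0]):
    # One-hot tangent selecting input dimension j for every sample
    tgt = torch.zeros_like(x)
    tgt[:, j] = 1.
    _, grad = vmap(jvp_func)(x, tgt)  # grad[b] = column j of the Jacobian at x[b]
    tr += torch.sum(grad * grad, dim=1)
```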

As you can see, my code computes a batched Jacobian column by column (one column per iteration of the j loop) and accumulates the trace. It is based on this code: https://github.com/facebookresearch/jacobian_regularizer/blob/main/jacobian/jacobian.py. I am doing this instead of computing the entire Jacobian at once because the full Jacobian is huge and runs out of memory. However, this code is quite slow. I am not sure if this code is doing a lot of redundant computation, e.g., I wonder if net(x) is being calculated repetitively on each loop of j.
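
A minimal sketch to check what the loop computes. It assumes a small hypothetical two-layer net (W1, W2 and the sizes B, D, H, K are made up for illustration) and uses torch.func, the successor of functorch that ships with PyTorch 2.x. Summing the squared entries of every column gives, per sample, sum_j ||J e_j||^2 = tr(J^T J) = ||J||_F^2; the sketch compares the loop against full per-sample Jacobians from jacrev, which is affordable here only because D is tiny.

```python
import torch
from torch.func import jacrev, jvp, vmap

torch.manual_seed(0)
B, D, H, K = 4, 5, 8, 3  # hypothetical sizes: batch, input, hidden, output
W1 = torch.randn(D, H)
W2 = torch.randn(H, K)


def net(x):
    # Hypothetical stand-in for the user's model
    return torch.tanh(x @ W1) @ W2


x = torch.randn(B, D)


def jvp_func(x, tgt):
    return jvp(net, (x,), (tgt,))


# Column-by-column loop from the issue: one Jacobian column per iteration
tr = torch.zeros(B)
for j in range(D):
    tgt = torch.zeros_like(x)
    tgt[:, j] = 1.0
    _, grad = vmap(jvp_func)(x, tgt)  # grad[b] = J(x[b]) @ e_j
    tr += torch.sum(grad * grad, dim=1)

# Reference for small D only: full per-sample Jacobians of shape [B, K, D]
J = vmap(jacrev(net))(x)
frob_sq = (J ** 2).sum(dim=(1, 2))  # ||J||_F^2 per sample
trace_JtJ = (J.transpose(1, 2) @ J).diagonal(dim1=1, dim2=2).sum(dim=-1)  # tr(J^T J)
```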

Is there a way to parallelize the j loop, or at least to remove the repeated computation in each iteration, to speed up the current code? I briefly looked at functorch.compile.ts_compile but could not get it to work, and I am not sure whether it would help here.

Any suggestions will be highly appreciated!

Thank you, Best regards, Kiwan

zou3519 commented 1 year ago

> However, this code is quite slow. I am not sure if this code is doing a lot of redundant computation, e.g., I wonder if net(x) is being calculated repetitively on each loop of j.

Yes, net(x) is recomputed on every iteration of the j loop. To avoid this, we need either something like jax.linearize or a good backend compiler that performs common subexpression elimination (CSE). cc @Chillee: would AOTAutograd help here? I don't know where CSE is implemented.
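
To make the repeated forward pass visible, here is a sketch that counts Python-level calls to a hypothetical net (the counter n_calls, W1, W2 and the sizes are illustrative assumptions). It also tries torch.func.linearize, the PyTorch counterpart of jax.linearize, assuming PyTorch 2.1 or later where it is available; it is not part of the standalone functorch package. linearize evaluates and traces net up front and returns a jvp_fn that replays only the traced linear map. The sketch further assumes net treats each row of x independently, so a batched tangent whose rows are all e_j yields column j of every sample's Jacobian without an inner vmap.

```python
import torch
from torch.func import jvp, linearize

torch.manual_seed(0)
B, D, H, K = 4, 5, 8, 3  # hypothetical sizes: batch, input, hidden, output
W1 = torch.randn(D, H)
W2 = torch.randn(H, K)
n_calls = [0]  # hypothetical counter of Python-level calls to net


def net(x):
    # Row-wise stand-in for the user's model that records each invocation
    n_calls[0] += 1
    return torch.tanh(x @ W1) @ W2


x = torch.randn(B, D)


def tangent(j):
    # Every row of the batch gets the one-hot vector e_j
    t = torch.zeros_like(x)
    t[:, j] = 1.0
    return t


# Plain jvp: the forward pass of net is redone for every j
tr_jvp = torch.zeros(B)
for j in range(D):
    _, col = jvp(net, (x,), (tangent(j),))  # col[b] = column j of J at x[b]
    tr_jvp += (col ** 2).sum(dim=1)
calls_jvp = n_calls[0]

# linearize: net is evaluated and traced up front, before the loop
n_calls[0] = 0
out, jvp_fn = linearize(net, x)
calls_after_linearize = n_calls[0]
tr_lin = torch.zeros(B)
for j in range(D):
    col = jvp_fn(tangent(j))  # replays the traced linear map without calling net
    tr_lin += (col ** 2).sum(dim=1)
calls_after_loop = n_calls[0]
```

The torch.func.linearize documentation notes that it saves intermediate computation and therefore needs more memory than calling jvp directly, so it trades memory for avoiding the recomputation.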

> Is there a way to parallelize the j loop

I assume that replacing the j loop with a single vmap blows up your memory (you said that the entire Jacobian is huge). So instead of vmapping over all of j at once, we can vmap over it in chunks.

For example, instead of running the for-loop over one j at a time, we could process two js at a time:

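```python
import torch
from functorch import vmap, jvp

def jvp_func(x, tgt):
    return jvp(net, (x,), (tgt,))

def make_tgt(j):
    # One-hot tangent for input dimension j, replicated across the batch
    tgt = torch.zeros_like(x)
    tgt[:, j] = 1.
    return tgt

def make_tgts(js):
    # Tangents for a chunk of js, shape [len(js), B, D]
    return torch.stack([make_tgt(j) for j in js])

def compute_tr(tgt):
    # Inner vmap runs over the batch; x is shared by every tangent in the chunk
    _, grad = vmap(jvp_func)(x, tgt)
    return torch.sum(grad * grad, dim=1)

D = x.shape[1]
chunk_size = 2  # number of js handled per vmap call
tr = 0
for start in range(0, D, chunk_size):
    js = range(start, min(start + chunk_size, D))
    tgts = make_tgts(js)
    # vmap over the chunk gives shape [len(js), B]; sum over the chunk
    tr = tr + vmap(compute_tr)(tgts).sum(dim=0)
```
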
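
The same chunking can also be expressed with the chunk_size argument of torch.func.vmap (present in PyTorch 2.x; whether a given older functorch release supports it is an assumption to check). A minimal sketch with the same hypothetical net and sizes, again assuming net processes rows of x independently: the one-hot tangent is built from the index j by comparison and broadcasting rather than an in-place write, so it can be created inside the vmapped function, and only chunk_size tangents are materialized per inner call.

```python
import torch
from torch.func import jacrev, jvp, vmap

torch.manual_seed(0)
B, D, H, K = 4, 5, 8, 3  # hypothetical sizes: batch, input, hidden, output
W1 = torch.randn(D, H)
W2 = torch.randn(H, K)


def net(x):
    # Hypothetical row-wise stand-in for the user's model
    return torch.tanh(x @ W1) @ W2


x = torch.randn(B, D)


def sq_col_norms(j):
    # j is a 0-d index tensor; broadcasting builds a [B, D] tangent whose
    # rows are all e_j, with no in-place indexed write
    tgt = torch.ones_like(x) * (torch.arange(D) == j).to(x.dtype)
    _, col = jvp(net, (x,), (tgt,))  # col[b] = column j of J at x[b]
    return (col ** 2).sum(dim=1)  # shape [B]


# vmap over every j, but only chunk_size js are evaluated per inner call
per_j = vmap(sq_col_norms, chunk_size=2)(torch.arange(D))  # shape [D, B]
tr = per_j.sum(dim=0)

# Check against full per-sample Jacobians (affordable only because D is tiny)
ref = vmap(jacrev(net))(x).pow(2).sum(dim=(1, 2))
```

Raising chunk_size trades memory for fewer sequential steps; chunk_size=None would vmap over all js at once, which is the memory blow-up the chunking is meant to avoid.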
kwmaeng91 commented 1 year ago

Thanks for the help! The code became much faster after vmapping over chunks of j. I am still interested in whether I can easily avoid recomputing net(x) every time, so I will wait for @Chillee to respond. I think that with torch.autograd this can be done by calling torch.autograd.grad() with retain_graph=True? I wonder if functorch has an equivalent, in case a new compiler cannot help.
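
The torch.autograd idea could look like the sketch below (same hypothetical net and sizes; net is assumed to treat rows of x independently). Because torch.autograd.grad is reverse mode, each call with a one-hot grad_outputs returns a row of each sample's Jacobian rather than a column; the sum of squared rows is the same squared Frobenius norm. The forward pass runs once, and retain_graph=True keeps the recorded graph alive across the K calls. The closed-form Jacobian at the end holds only for this particular net and is there purely as a check.

```python
import torch

torch.manual_seed(0)
B, D, H, K = 4, 5, 8, 3  # hypothetical sizes: batch, input, hidden, output
W1 = torch.randn(D, H)
W2 = torch.randn(H, K)


def net(x):
    # Hypothetical row-wise stand-in for the user's model
    return torch.tanh(x @ W1) @ W2


x = torch.randn(B, D, requires_grad=True)
out = net(x)  # single forward pass; autograd records the graph

tr = torch.zeros(B)
for i in range(K):
    v = torch.zeros_like(out)
    v[:, i] = 1.0  # one-hot cotangent selecting output i for every sample
    # Reverse mode yields row i of each sample's Jacobian; retain_graph=True
    # keeps the graph so the next iteration can backpropagate through it again
    (row,) = torch.autograd.grad(out, x, grad_outputs=v, retain_graph=True)
    tr += (row ** 2).sum(dim=1)

# Closed-form Jacobian of this particular net, used only as a check:
# J[b, k, d] = sum_h W2[h, k] * (1 - tanh(z[b, h])**2) * W1[d, h], with z = x @ W1
s = 1 - torch.tanh(x.detach() @ W1) ** 2  # shape [B, H]
J = torch.einsum("hk,bh,dh->bkd", W2, s, W1)
ref = (J ** 2).sum(dim=(1, 2))
```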

zou3519 commented 1 year ago

> torch.autograd, this can be done by using torch.autograd.grad(), with retain_graph=True

If you're okay with computing vjps instead of jvps, then you can use functorch.vjp (the transform version of torch.autograd.grad): it computes net(x) once and returns a vjp_fn that you can call multiple times.
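
A sketch of the vjp route using torch.func.vjp (functorch.vjp in the standalone package), again with the hypothetical net and sizes and assuming rows of x are independent. Since the sum of squared rows of J equals the sum of squared columns (both are ||J||_F^2), switching from jvp to vjp does not change the result; the loop simply runs over the K output dimensions instead of the D input dimensions, which is cheaper when K < D. Here vjp_fn is vmapped over one-hot cotangents with chunk_size to bound memory, mirroring the chunked jvp loop.

```python
import torch
from torch.func import jacrev, vjp, vmap

torch.manual_seed(0)
B, D, H, K = 4, 5, 8, 3  # hypothetical sizes: batch, input, hidden, output
W1 = torch.randn(D, H)
W2 = torch.randn(H, K)


def net(x):
    # Hypothetical row-wise stand-in for the user's model
    return torch.tanh(x @ W1) @ W2


x = torch.randn(B, D)

# net(x) is evaluated once here; vjp_fn can be called as often as needed
out, vjp_fn = vjp(net, x)


def sq_row_norms(i):
    # One-hot cotangent selecting output i for every sample
    v = torch.ones_like(out) * (torch.arange(K) == i).to(out.dtype)
    (row,) = vjp_fn(v)  # row[b] = row i of J at x[b]
    return (row ** 2).sum(dim=1)  # shape [B]


# Loop over the K outputs in chunks instead of over the D inputs
tr = vmap(sq_row_norms, chunk_size=2)(torch.arange(K)).sum(dim=0)

# Check against full per-sample Jacobians (affordable only because D is tiny)
ref = vmap(jacrev(net))(x).pow(2).sum(dim=(1, 2))
```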