pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Jacobian of matrix exponential yields an "in-place" runtime error #715

Open jenninglim opened 2 years ago

jenninglim commented 2 years ago

I am trying to calculate the Jacobian of a matrix exponential. However, it yields the following error:

RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.

Is there a fix for this?

The error can be reproduced by the following:

from functorch import vmap, jacrev, jacfwd
import torch

def expm(a):
    return torch.linalg.matrix_exp(a)

x = torch.eye(3)
jacfwd(expm)(x)  # raises the vmap "inplace arithmetic" RuntimeError above
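
For reference, one possible stopgap (a sketch, not verified against this exact case) is to fall back to torch.autograd.functional.jacobian, which by default computes the Jacobian one row at a time with ordinary autograd and so never goes through vmap:

import torch

def expm(a):
    return torch.linalg.matrix_exp(a)

x = torch.eye(3)
# vectorize defaults to False, so each row of the Jacobian comes from a plain
# backward() call and the vmap in-place restriction never applies.
jac = torch.autograd.functional.jacobian(expm, x)
print(jac.shape)  # (3, 3, 3, 3): output shape followed by input shape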

Thanks.

Chillee commented 2 years ago

Seems like a composite compliance issue @zou3519

I thought that functionalization should be able to fix this, but it seems to segfault. Any ideas, @bdhirsh?
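
For reference, the combination being discussed would look roughly like this (a sketch; the functorch.experimental.functionalize entry point is an assumption about where functionalize is exposed):

import torch
from functorch import jacfwd
from functorch.experimental import functionalize  # assumed entry point

def expm(a):
    return torch.linalg.matrix_exp(a)

x = torch.eye(3)
# functionalize rewrites in-place and view ops in the traced function into
# out-of-place equivalents before jacfwd's vmap sees them; per the discussion
# above, this combination currently segfaults instead of working.
jacfwd(functionalize(expm))(x)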

bdhirsh commented 2 years ago

I have a local fix for the jacfwd(functionalize(f)) segfault, but it doesn't fix the underlying problem here.

Is this actually a composite compliance issue, though? It looks like matrix_exp isn't a composite op (it has its own native function entry).

bdhirsh commented 2 years ago

Oh hmm, it looks like it's because the derivative formula for matrix_exp isn't "composite compliant" (although in this case the derivative formula isn't an op, it's just a function, defined here, where it uses a bunch of view + in-place ops).
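
To illustrate the kind of pattern that trips vmap (a minimal sketch, not the actual matrix_exp derivative code): mutating a non-batched tensor in place with a batched argument hits exactly this error, while the out-of-place equivalent is fine.

import torch
from functorch import vmap

def uses_inplace(x):
    buf = torch.zeros(3)      # plain tensor, not vmapped over
    buf.add_(x)               # in-place add with a vmapped `other` -> the error above
    return buf

def uses_out_of_place(x):
    buf = torch.zeros(3)
    return buf + x            # out-of-place broadcasting works under vmap

xs = torch.randn(5, 3)
vmap(uses_out_of_place)(xs)   # works, returns shape (5, 3)
# vmap(uses_inplace)(xs)      # raises the same "inplace arithmetic" RuntimeError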

I wonder if there's a reasonable way to get functionalization to interpose in the backward pass.

Chillee commented 2 years ago

Why shouldn't the backwards pass be functionalizable?