pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

fails to derive jacobian with functorch after operation torch.stft #1077

Closed lylyhan closed 1 year ago

lylyhan commented 1 year ago

Hello, I am trying to derive the Jacobian matrix of a short-time Fourier transform operation (via torch.stft) with respect to an input signal. The code I used was:

import torch
import functorch
sig = torch.rand(2**13)

def stft_forward(sig):
    return torch.stft(sig, n_fft=2048).flatten()
J = functorch.jacfwd(stft_forward)(sig)

yet the following error message is triggered:

File .../test_lstsq.ipynb, line 5
----> 5 J = functorch.jacfwd(stft_forward)(sig)

File ~/miniconda3/envs/icassp/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:961, in jacfwd.<locals>.wrapper_fn(*args)
    958     _, jvp_out = output
    959     return jvp_out
--> 961 results = vmap(push_jvp)(basis)
    962 if has_aux:
    963     results, aux = results

File ~/miniconda3/envs/icassp/lib/python3.9/site-packages/functorch/_src/vmap.py:365, in vmap.<locals>.wrapped(*args, **kwargs)
    363 try:
    364     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 365     batched_outputs = func(*batched_inputs, **kwargs)
    366     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    367 finally:

File ~/miniconda3/envs/icassp/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:954, in jacfwd.<locals>.wrapper_fn.<locals>.push_jvp(basis)
    953 def push_jvp(basis):
--> 954     output = jvp(f_wrapper, primals, basis, has_aux=has_aux)
    955     if has_aux:
...
    605     input = input.view(input.shape[-signal_dim:])
--> 606 return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
    607                 normalized, onesided, return_complex)

RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor, this is not supported. Tensor is of size [1025, 17, 2] while the given forward gradient is of size [17, 1025, 2].

Would anyone know what aspect of the torch.stft operation causes this gradient shape incompatibility, and how one might work around it? Thanks in advance!
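
For what it's worth, the traceback suggests that jacfwd is implemented as vmap over jvp with one one-hot tangent vector per input element, so the [17, 1025, 2] size in the error looks like a batched forward gradient whose leading batch dimension ended up in the wrong position. A rough, self-contained sketch of that decomposition, reusing the push_jvp and basis names from the traceback (the commented-out vmap call is what reproduces the error on my setup):

import torch
import functorch

sig = torch.rand(2**13)

def stft_forward(sig):
    return torch.stft(sig, n_fft=2048).flatten()

# jacfwd(f)(x) is roughly vmap over jvp, pushing one one-hot tangent per input element
basis = torch.eye(sig.numel())

def push_jvp(tangent):
    _, jvp_out = functorch.jvp(stft_forward, (sig,), (tangent,))
    return jvp_out

# J = functorch.vmap(push_jvp)(basis)  # fails with the same size mismatch on 0.2.1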

kshitij12345 commented 1 year ago

Thanks for reporting. I tried to reproduce the error but couldn't. The following snippet works for me with the latest version of PyTorch/functorch. Which version are you using?

import torch
import functorch
sig = torch.rand(2**13)

def stft_forward(sig):
    return torch.stft(sig, n_fft=2048, return_complex=True).flatten()  # return_complex=True is required for real inputs
J = functorch.jacfwd(stft_forward)(sig)
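
For reference, a quick shape sanity check on the result (a sketch; the sizes assume the default hop_length = n_fft // 4 and center=True):

n_freq = 2048 // 2 + 1               # 1025 one-sided frequency bins
n_frames = 1 + 2**13 // (2048 // 4)  # 17 frames with the default hop and centering
assert J.shape == (n_freq * n_frames, sig.numel())  # one Jacobian row per output element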

lylyhan commented 1 year ago

Thanks for your reply. My functorch and torch versions are the following:

>>> functorch.__version__
'0.2.1'
>>> torch.__version__
'1.12.1'

Which version did you use that avoided this error?

kshitij12345 commented 1 year ago

Thanks for confirming, @lylyhan! It works for me with 1.13.0.

NOTE: With 1.13, you don't need to install functorch separately as it is already bundled with PyTorch.
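
If upgrading is not an option, one possible workaround (a sketch, not verified on functorch 0.2.1) is to compute the Jacobian in reverse mode with jacrev, which avoids the forward-gradient code path, mapping the complex output to a real view first:

import torch
import functorch

sig = torch.rand(2**13)

def stft_forward(sig):
    # view_as_real turns the complex spectrogram into a real (..., 2) tensor
    spec = torch.stft(sig, n_fft=2048, return_complex=True)
    return torch.view_as_real(spec).flatten()

J = functorch.jacrev(stft_forward)(sig)  # reverse mode, no forward gradients involved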

zou3519 commented 1 year ago

Closing as fixed, but please let us know if you're still experiencing the problem on a newer version of pytorch/functorch.
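
For readers landing here later: on PyTorch 2.0 and newer, the same transforms are exposed under torch.func, so the working snippet above can be written without importing functorch at all. A minimal sketch:

import torch

sig = torch.rand(2**13)

def stft_forward(sig):
    return torch.stft(sig, n_fft=2048, return_complex=True).flatten()

J = torch.func.jacfwd(stft_forward)(sig)  # torch.func supersedes the standalone functorch package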