Hi everyone,
I was trying to retrieve per-sample gradients for a GRU-based model, following the functorch documentation, but I get the following error:
Traceback (most recent call last):
File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 51, in <module>
ft_sample_grads = ft_compute_sample_grad(params, buffers, x, t, hx)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 362, in wrapped
return _flat_vmap(
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 35, in fn
return f(*args, **kwargs)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 489, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\eager_transforms.py", line 1241, in wrapper
results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 35, in fn
return f(*args, **kwargs)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\eager_transforms.py", line 1111, in wrapper
output = func(*args, **kwargs)
File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 26, in compute_loss_stateless_model
prediction = fmodel(params, buffers, sample.unsqueeze(1), state.unsqueeze(1))
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\make_functional.py", line 282, in forward
return self.stateless_model(*args, **kwargs)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 20, in forward
x, _ = self.recurrent(x, hx)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\rnn.py", line 955, in forward
result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Batching rule not implemented for aten::unsafe_split.Tensor. We could not generate a fallback.
A vanilla RNN works correctly. The code I've used is the following:
from functools import partial
from typing import Type, Union
import torch
from functorch import grad, make_functional_with_buffers, vmap
class Recurrent(torch.nn.Module):
    """A single recurrent layer (GRU or vanilla RNN) followed by a linear read-out.

    Inputs are time-major: ``x`` is ``(seq_len, batch, input_size)`` and
    ``hx`` is ``(1, batch, hidden_size)``; the output is
    ``(seq_len, batch, output_size)``.
    """

    def __init__(
        self,
        recurrent_layer: Union[Type[torch.nn.GRU], Type[torch.nn.RNN]],
        input_size: int,
        hidden_size: int,
        output_size: int,
    ) -> None:
        super().__init__()
        # batch_first=False keeps the time-major (T, B, D) layout.
        self.recurrent = recurrent_layer(
            input_size=input_size, hidden_size=hidden_size, batch_first=False
        )
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, hx: torch.Tensor) -> torch.Tensor:
        # The final hidden state is discarded; only the per-step outputs are used.
        hidden_seq, _ = self.recurrent(x, hx)
        return self.fc(torch.relu(hidden_seq))
def compute_loss_stateless_model(fmodel, params, buffers, sample, target, state):
    """Return the MSE loss of a stateless (functional) model on one sample.

    ``sample``, ``target`` and ``state`` arrive without a batch axis (they
    are single slices taken out of the batch, e.g. by ``vmap``); a singleton
    batch dimension of size 1 is inserted at dim 1 before the forward pass.
    """
    batched_sample = sample.unsqueeze(1)
    batched_state = state.unsqueeze(1)
    prediction = fmodel(params, buffers, batched_sample, batched_state)
    return torch.nn.functional.mse_loss(prediction, target.unsqueeze(1))
if __name__ == "__main__":
    # Dimensions: sequence length, batch, input, hidden, output.
    T, B, D, H, O = 128, 64, 64, 256, 1
    x = torch.rand(T, B, D)
    t = torch.ones(T, B, O)
    hx = torch.zeros(1, B, H)
    gru = Recurrent(torch.nn.GRU, D, H, O)
    rnn = Recurrent(torch.nn.RNN, D, H, O)

    # Per-sample gradients via functorch: the RNN case runs first and works;
    # the GRU case reproduces the "Batching rule not implemented for
    # aten::unsafe_split.Tensor" error. vmap maps over the batch axis
    # (dim 1) of x, t and hx while broadcasting params and buffers.
    for model in (rnn, gru):
        fmodel, params, buffers = make_functional_with_buffers(model)
        per_sample_grad_fn = vmap(
            grad(partial(compute_loss_stateless_model, fmodel)),
            in_dims=(None, None, 1, 1, 1),
        )
        sample_grads = per_sample_grad_fn(params, buffers, x, t, hx)
        for g in sample_grads:
            print(g.shape)
The collected environment is the following:
PyTorch version: 1.13.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.8.10 (tags/v3.8.10:3d8993a, May 3 2021, 11:48:03) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19044-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070 SUPER
Nvidia driver version: 516.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] functorch==1.13.0
[pip3] mypy==0.931
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==1.8.3.post1
[pip3] torch==1.13.0
[pip3] torchmetrics==0.11.0
[conda] Could not collect
Hi everyone, I was trying to retrieve per-sample gradients for a GRU-based model, following the functorch documentation, but I get the following error:
A vanilla RNN works correctly. The code I've used is the following:
The collected environment is the following:
Thank you, Federico