pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

vmap + GRU #1089

Closed · belerico closed 1 year ago

belerico commented 1 year ago

Hi everyone, I was trying to compute per-sample gradients for a GRU-based model by following the functorch documentation, but I get the following error:

Traceback (most recent call last):
  File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 51, in <module>
    ft_sample_grads = ft_compute_sample_grad(params, buffers, x, t, hx)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 362, in wrapped
    return _flat_vmap(
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 489, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\eager_transforms.py", line 1241, in wrapper
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\eager_transforms.py", line 1111, in wrapper
    output = func(*args, **kwargs)
  File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 26, in compute_loss_stateless_model
    prediction = fmodel(params, buffers, sample.unsqueeze(1), state.unsqueeze(1))
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\make_functional.py", line 282, in forward
    return self.stateless_model(*args, **kwargs)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 20, in forward
    x, _ = self.recurrent(x, hx)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\rnn.py", line 955, in forward
    result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Batching rule not implemented for aten::unsafe_split.Tensor. We could not generate a fallback.
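
The traceback points at the fused _VF.gru kernel, which internally reaches aten::unsafe_split, and that op has no vmap batching rule; so the failure should reproduce with vmap alone, without grad. A minimal probe (a sketch, assuming the same environment):

import torch
from functorch import vmap

gru = torch.nn.GRU(input_size=4, hidden_size=8)
x = torch.rand(16, 3, 4)    # (T, B, D)
hx = torch.zeros(1, 3, 8)   # (num_layers, B, H)

def step(sample, state):
    # vmap strips the mapped batch dim, so add a size-1 batch dim back
    out, _ = gru(sample.unsqueeze(1), state.unsqueeze(1))
    return out

# mapping over dim 1 (the batch) should hit the same
# "Batching rule not implemented for aten::unsafe_split.Tensor" error
vmap(step, in_dims=(1, 1))(x, hx)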

A vanilla RNN works correctly. The code I've used is the following:

from functools import partial
from typing import Type, Union

import torch
from functorch import grad, make_functional_with_buffers, vmap

class Recurrent(torch.nn.Module):
    def __init__(
        self,
        recurrent_layer: Union[Type[torch.nn.GRU], Type[torch.nn.RNN]],
        input_size: int,
        hidden_size: int,
        output_size: int,
    ) -> None:
        super().__init__()
        self.recurrent = recurrent_layer(input_size=input_size, hidden_size=hidden_size, batch_first=False)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, hx: torch.Tensor) -> torch.Tensor:
        x, _ = self.recurrent(x, hx)
        x = self.fc(torch.relu(x))
        return x

def compute_loss_stateless_model(fmodel, params, buffers, sample, target, state):
    # vmap strips the mapped batch dimension, so re-insert a size-1 batch dim
    prediction = fmodel(params, buffers, sample.unsqueeze(1), state.unsqueeze(1))
    loss = torch.nn.functional.mse_loss(prediction, target.unsqueeze(1))
    return loss

if __name__ == "__main__":
    T, B, D, H, O = 128, 64, 64, 256, 1
    x = torch.rand(T, B, D)
    t = torch.ones(T, B, O)
    hx = torch.zeros(1, B, H)
    gru = Recurrent(torch.nn.GRU, D, H, O)
    rnn = Recurrent(torch.nn.RNN, D, H, O)

    # functional RNN + vmap: this works
    frnn, params, buffers = make_functional_with_buffers(rnn)
    ft_compute_grad = grad(partial(compute_loss_stateless_model, frnn))
    # params and buffers are shared (None); x, t and hx are mapped over dim 1 (batch)
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1, 1))
    ft_sample_grads = ft_compute_sample_grad(params, buffers, x, t, hx)
    for g in ft_sample_grads:
        print(g.shape)

    # functional GRU + vmap: this raises the RuntimeError above
    fgru, params, buffers = make_functional_with_buffers(gru)
    ft_compute_grad = grad(partial(compute_loss_stateless_model, fgru))
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1, 1))
    ft_sample_grads = ft_compute_sample_grad(params, buffers, x, t, hx)
    for g in ft_sample_grads:
        print(g.shape)
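
Until a batching rule for aten::unsafe_split lands, one possible workaround is to hand-roll the GRU cell from matmuls and pointwise ops, so that vmap never reaches the fused _VF.gru kernel. A sketch (ManualGRUCell and its wiring are illustrative, not an existing API):

class ManualGRUCell(torch.nn.Module):
    """One GRU step built only from ops that vmap can batch."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.ih = torch.nn.Linear(input_size, 3 * hidden_size)
        self.hh = torch.nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        H = self.hidden_size
        gi, gh = self.ih(x), self.hh(h)
        # plain slicing instead of chunk/split, to stay on vmap-friendly ops
        r = torch.sigmoid(gi[..., :H] + gh[..., :H])            # reset gate
        z = torch.sigmoid(gi[..., H:2 * H] + gh[..., H:2 * H])  # update gate
        n = torch.tanh(gi[..., 2 * H:] + r * gh[..., 2 * H:])   # candidate state
        return (1 - z) * n + z * h

Replacing self.recurrent in Recurrent.forward with a loop over the time dimension of this cell keeps the whole computation inside ops vmap supports; the trade-off is giving up the fused kernel's speed.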

The collected environment is the following:

PyTorch version: 1.13.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.10 (tags/v3.8.10:3d8993a, May  3 2021, 11:48:03) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19044-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070 SUPER
Nvidia driver version: 516.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] functorch==1.13.0
[pip3] mypy==0.931
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==1.8.3.post1
[pip3] torch==1.13.0
[pip3] torchmetrics==0.11.0
[conda] Could not collect

Thank you, Federico

kshitij12345 commented 1 year ago

Fixed in #92291
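
For reference, in builds that include the fix the functorch transforms live in torch.func, and the same per-sample-gradient recipe can be written against that API. A sketch assuming torch >= 2.0 and the Recurrent model from above:

import torch
from torch.func import functional_call, grad, vmap

T, B, D, H, O = 128, 64, 64, 256, 1
gru = Recurrent(torch.nn.GRU, D, H, O)
params = dict(gru.named_parameters())

def compute_loss(params, sample, target, state):
    # functional_call replaces make_functional_with_buffers in torch.func
    prediction = functional_call(gru, params, (sample.unsqueeze(1), state.unsqueeze(1)))
    return torch.nn.functional.mse_loss(prediction, target.unsqueeze(1))

x, t, hx = torch.rand(T, B, D), torch.ones(T, B, O), torch.zeros(1, B, H)
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 1, 1, 1))(params, x, t, hx)
for name, g in per_sample_grads.items():
    print(name, g.shape)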