pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

`functorch.grad` slower than `torch.autograd.grad` #1085

Closed: abdulfatir closed this issue 1 year ago

abdulfatir commented 1 year ago

functorch.grad appears to be almost 2.5 times slower than torch.autograd.grad in my simple test below. Is this expected?

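import timeit
import torch
import functorch

def test_func(z):
    return (z ** 2).sum()

def pytorch_grad(x):
    # Note: this marks the caller's tensor as requiring grad (in place)
    x.requires_grad_(True)
    return torch.autograd.grad(test_func(x), x)[0]

# functorch.grad(f) returns a new function that computes the gradient of f
functorch_grad = functorch.grad(test_func)
z = torch.rand((2, ))

assert torch.allclose(functorch_grad(z), pytorch_grad(z))

# %timeit is an IPython/Jupyter magic; in plain Python use e.g.
# timeit.timeit(lambda: functorch_grad(z), number=10_000)
%timeit functorch_grad(z) # 125 µs ± 369 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit pytorch_grad(z) # 51.9 µs ± 531 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)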

Versions

Collecting environment information...
PyTorch version: 1.13.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.6.1 (x86_64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.15 (default, Nov 24 2022, 09:04:07)  [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==1.8.4.post0
[pip3] torch==1.13.0
[pip3] torchmetrics==0.11.0
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] pytorch-lightning         1.8.4.post0              pypi_0    pypi
[conda] torch                     1.13.0                   pypi_0    pypi
[conda] torchmetrics              0.11.0                   pypi_0    pypi

zou3519 commented 1 year ago

There is overhead associated with functorch.grad, so yes, it's expected that functorch.grad (by itself) is slower than torch.autograd.grad. The overhead becomes less noticeable as the program gets larger. functorch.grad is a wrapper around torch.autograd.grad that does some additional bookkeeping. That overhead is necessary because functorch.grad can be arbitrarily nested with other function transforms (like vmap), which is not possible with torch.autograd.grad.
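To illustrate the composability described here, a minimal sketch (assuming PyTorch 2.x's `torch.func`, with a fallback to the standalone `functorch` package) of `grad` composed with `vmap` to get per-sample gradients, and of `grad` nested with itself to get a second derivative. The functions `loss`, `cube`, and `per_sample_loop` are hypothetical examples; `per_sample_loop` is the Python-loop equivalent using `torch.autograd.grad`, one call per sample.

```python
import torch

try:
    from torch.func import grad, vmap  # PyTorch >= 2.0
except ImportError:
    from functorch import grad, vmap  # standalone functorch (PyTorch 1.13)


def loss(x):
    # Scalar-valued function of a 1-D tensor; its gradient is 2 * x
    return (x ** 2).sum()


def cube(x):
    # Scalar-valued function of a 0-D tensor; second derivative is 6 * x
    return x ** 3


# 1. grad composed with vmap: one gradient per row of the batch, no Python loop
batch = torch.rand(4, 3)
per_sample = vmap(grad(loss))(batch)  # shape (4, 3)


def per_sample_loop(b):
    # Same result with torch.autograd.grad, one backward pass per row
    rows = []
    for row in b:
        row = row.detach().requires_grad_(True)
        rows.append(torch.autograd.grad(loss(row), row)[0])
    return torch.stack(rows)


# 2. grad nested with grad: d^2/dx^2 of x**3 at x = 2 is 6 * 2 = 12
second = grad(grad(cube))(torch.tensor(2.0))
```

The loop version gives the same numbers here, but `vmap(grad(...))` expresses the batching as a transform, which is the kind of composition the extra bookkeeping pays for.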

Is the overhead a problem for you? There is likely some low-hanging fruit we can pick to cut some of the overhead, and it's on our roadmap to ensure that the upcoming torch.compile feature helps reduce the overhead of functorch transforms.

abdulfatir commented 1 year ago

Thanks, I did look into the source code after posting this issue and realized that the behavior could be due to the additional bookkeeping. I am curious to know which bookkeeping operations slow things down though.

My use case was mainly the cleanliness of the code that the functorch API provides, but I guess something similar could be achieved by writing a wrapper function around torch.autograd.grad.

zou3519 commented 1 year ago

If you're interested, https://docs.google.com/document/d/14qyaa3xIjmVxYiMLlIlQErunYgR_uR1WupsKMZlnGY4/edit# goes over the design (though it may assume knowledge of how PyTorch internals work). That document and other information are linked from the PyTorch Wiki for the curious (https://github.com/pytorch/pytorch/wiki/Core-Frontend-Onboarding).

abdulfatir commented 1 year ago

Thank you!