pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

[Question] Forward-mode AD `jvp` is slower than reverse-mode AD `vjp` #1083

Closed XuehaiPan closed 1 year ago

XuehaiPan commented 1 year ago

I'm using functorch to compute the Hessian-vector product (hvp) for my model. Since the Hessian matrix is symmetric, the hvp and the vhp should be transposes of each other. I'm wondering what's the best way to compute the hvp in practice (this is speed-sensitive).

The documentation functorch Tutorials: Computing Hessian-vector products says:

Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute a hvp because forward-mode AD doesn’t need to construct an Autograd graph and save intermediates for backward.

So the reverse + forward composition is more memory efficient. What about runtime? Which one is recommended? Many thanks!
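
As a quick sanity check of the symmetry argument, here is a small sketch of my own (not taken from the tutorial) showing that the two compositions return the same vector for a scalar-valued function:

import torch
from functorch import grad, jvp, vjp

def f(x):
    return x.sin().sum()

x = torch.randn(2048)
v = torch.randn(2048)

# hvp via forward-over-reverse: jvp of grad(f) computes H @ v.
hvp_out = jvp(grad(f), (x,), (v,))[1]

# vhp via reverse-over-reverse: vjp of grad(f) computes v @ H, which equals
# H @ v because the Hessian of a scalar-valued function is symmetric.
_, vjp_fn = vjp(grad(f), x)
vhp_out = vjp_fn(v)[0]

assert torch.allclose(hvp_out, vhp_out, atol=1e-6)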

A small snippet copied from Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms:

import torch
from functorch import jvp, grad, vjp
from torch.utils.benchmark import Timer

def f(x):
    return x.sin().sum()

def hvp_revfwd(f, primals, tangents):
    # forward-over-reverse: jvp of grad(f)
    return jvp(grad(f), primals, tangents)[1]

def hvp_revrev(f, primals, tangents):
    # reverse-over-reverse: vjp of grad(f)
    _, vjp_fn = vjp(grad(f), *primals)
    return vjp_fn(*tangents)

def get_perf(first, first_descriptor, second, second_descriptor):
    """takes torch.benchmark objects and compares delta of second vs first."""
    faster = second.times[0]
    slower = first.times[0]
    gain = (slower - faster) / slower
    if gain < 0:
        gain *= -1
    final_gain = gain * 100
    print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")

x = torch.randn(2048)
tangent = torch.randn(2048)

revfwd_timer = Timer("hvp_revfwd(f, (x,), (tangent,))", globals=globals()).timeit(10000)
revrev_timer = Timer("hvp_revrev(f, (x,), (tangent,))", globals=globals()).timeit(10000)

print(revfwd_timer)
print(revrev_timer)

get_perf(revfwd_timer, 'rev + fwd', revrev_timer, 'rev + rev')

Result:

<torch.utils.benchmark.utils.common.Measurement object at 0x7f4a7c09a670>
hvp_revfwd(f, (x,), (tangent,))
  384.73 us
  1 measurement, 10000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f4b38086400>
hvp_revrev(f, (x,), (tangent,))
  281.31 us
  1 measurement, 10000 runs , 1 thread
 Performance delta: 26.8806 percent improvement with rev + rev

For this function, the reverse-over-reverse composition (vjp of grad) is ~27% faster than the forward-over-reverse composition (jvp of grad). I haven't tested the memory cost yet.
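
One rough way to compare peak memory (a sketch of my own, assuming a CUDA device is available) is to reset and read PyTorch's peak-memory counters around each call:

import torch
from functorch import grad, jvp, vjp

device = torch.device("cuda")
x = torch.randn(2048, device=device)
tangent = torch.randn(2048, device=device)

def f(x):
    return x.sin().sum()

def peak_memory(fn):
    """Peak CUDA memory (bytes) allocated while running fn()."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

print("rev + fwd:", peak_memory(lambda: jvp(grad(f), (x,), (tangent,))[1]))

def revrev():
    _, vjp_fn = vjp(grad(f), x)
    return vjp_fn(tangent)[0]

print("rev + rev:", peak_memory(revrev))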

zou3519 commented 1 year ago

I think the answer is: it really depends on the function and the sizes of the inputs. There's no simple answer, so our recommendation is to benchmark and see which one is faster for your use case.
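
For example, the same comparison can be rerun on a heavier function than x.sin().sum() (the function and sizes below are only illustrative, not from this thread) to see whether the ranking changes:

import torch
from functorch import grad, jvp, vjp
from torch.utils.benchmark import Timer

# Illustrative stand-in for "your function": a dense matmul followed by tanh.
A = torch.randn(1024, 1024)

def g(x):
    return torch.tanh(A @ x).sum()

x = torch.randn(1024)
v = torch.randn(1024)

def hvp_revfwd(x, v):
    # forward-over-reverse
    return jvp(grad(g), (x,), (v,))[1]

def hvp_revrev(x, v):
    # reverse-over-reverse
    _, vjp_fn = vjp(grad(g), x)
    return vjp_fn(v)[0]

for label, stmt in [("rev + fwd", "hvp_revfwd(x, v)"), ("rev + rev", "hvp_revrev(x, v)")]:
    print(label)
    print(Timer(stmt, globals=globals()).timeit(1000))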

In terms of operator coverage, reverse-over-reverse is guaranteed to cover more PyTorch operations, because PyTorch's reverse-mode AD has been around a lot longer than its forward-mode AD.

cc @soulitzer @albanD if you disagree or have more to add.

XuehaiPan commented 1 year ago

In terms of operator coverage, reverse-over-reverse is guaranteed to cover more PyTorch operations, because PyTorch's reverse-mode AD has been around a lot longer than its forward-mode AD.

This makes sense to me. Thanks.