pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

`vmap(jacrev)` is slower than `functional.jacobian` #328

Closed: newalexander closed this issue 2 years ago

newalexander commented 2 years ago

We're computing batchwise jacobians of a network with respect to its inputs.

```python
import torch

from torch import nn
from functorch import vmap, jacrev
from torch.autograd.functional import jacobian

def functorch_jacobian():
    """calculate a jacobian tensor along a batch of inputs. returns something of size
    `batch_size` x `output_dim` x `input_dim`"""
    return vmap(jacrev(model))(points)

def pytorch_jacobian():
    """calculate a jacobian tensor along a batch of inputs. returns something of size
    `batch_size` x `output_dim` x `input_dim`"""
    def _func_sum(points):
        # summing over the batch works because samples do not interact
        return model(points).sum(dim=0)
    # jacobian of shape (output_dim, batch_size, input_dim) -> (batch_size, output_dim, input_dim)
    return jacobian(_func_sum, points, create_graph=True, vectorize=True).permute(1, 0, 2)

torch.manual_seed(1234)

n_input, n_output, n_batch, n_hidden = 3, 5, 128, 64

model = nn.Sequential(nn.Linear(n_input, n_hidden), nn.Tanh(), nn.Linear(n_hidden, n_output))
points = torch.rand((n_batch, n_input))

# %timeit is an IPython/Jupyter magic, not plain Python
%timeit pytorch_jacobian()  # 691 µs ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit functorch_jacobian()  # 732 µs ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

So the functorch Jacobian is a bit slower to calculate than the base PyTorch Jacobian. Is this to be expected, should I consider the difference fairly inconsequential, or am I conducting this comparison incorrectly? If it's expected, are there any planned optimizations for vmap(jacrev)?

Very nice work on the library by the way.

zou3519 commented 2 years ago

@newalexander I'm a bit confused. The direct analogue to torch.autograd.functional.jacobian should be functorch.jacrev, but in the above script it looks like we are comparing `vmap(jacrev)` with torch.autograd.functional.jacobian.

Edit: I was wrong; I see the difference now.

At any rate, we are interested in investigating and closing any performance gaps we find between functorch and torch.autograd.functional.

Chillee commented 2 years ago

To answer your question, the 2 approaches are morally identical.

Under the hood, torch.autograd.functional.jacobian actually uses an earlier version of vmap when you set vectorize=True. Using functorch's vmap may add slightly more overhead, but I'd have to actually run the model to know. The difference here is fairly small (about 5%), so I wouldn't be shocked if it came down to overhead. Since it looks like you're running on CPU, where overheads tend not to be hidden, that might explain the gap.

As for improving performance, there are a lot of possibilities. We may have suboptimal batching rules somewhere (or batching rules with more overhead than needed). We also have some prototype compilation-ish features that could work for this situation.
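When gaps are this small and overhead-dominated, a single `%timeit` run is easy to misread. A minimal sketch of a steadier comparison using `torch.utils.benchmark.Timer`, which handles warm-up and chooses the number of runs itself. It assumes PyTorch 2.x, where functorch's transforms are available as `torch.func`; note that `Timer` runs single-threaded by default (`num_threads=1`), so absolute numbers will differ from the figures above.

```python
import torch
from torch import nn
from torch.autograd.functional import jacobian
from torch.func import jacrev, vmap
from torch.utils import benchmark

torch.manual_seed(1234)
model = nn.Sequential(nn.Linear(3, 64), nn.Tanh(), nn.Linear(64, 5))
points = torch.rand(128, 3)


def func_sum(x):
    # Summing over the batch is valid because each output row depends
    # only on the matching input row.
    return model(x).sum(dim=0)


def pytorch_jacobian(x):
    # (output_dim, batch, input_dim) -> (batch, output_dim, input_dim)
    return jacobian(func_sum, x, create_graph=True, vectorize=True).permute(1, 0, 2)


def functorch_jacobian(x):
    # Per-sample Jacobians, shape (batch, output_dim, input_dim)
    return vmap(jacrev(model))(x)


def measure(fn, label):
    # blocked_autorange repeats the statement until min_run_time has elapsed.
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": points}, label=label)
    return timer.blocked_autorange(min_run_time=0.2)


results = [
    measure(pytorch_jacobian, "autograd.functional.jacobian"),
    measure(functorch_jacobian, "vmap(jacrev)"),
]
for m in results:
    print(m)  # reports median time per call and the number of measurements
```

Repeating the measurement with a larger `n_hidden` or batch is a quick way to see whether the gap is fixed per-call overhead (it shrinks relative to total time) or a real difference in the computation.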

Chillee commented 2 years ago

@newalexander I investigated it, and I have two main observations:

  1. The computation is not actually identical between functorch_jacobian and pytorch_jacobian. The first computes the batched Jacobian directly, while the second computes a single Jacobian of a function that sums over the batch at the end. In this case the results are identical, and the second option seems to be a bit more efficient. But we can use the same trick with functorch, and it appears to be a bit faster than autograd.functional.

Here, functorch_jacobian2 is the one implemented with the sum trick (average time per call, in µs):

```
pytorch_jacobian 964.2601013183594
functorch_jacobian 1162.0521545410156
functorch_jacobian2 921.9169616699219
```

  2. In practice, however, for these input sizes and modules the runtime is dominated by overhead. functorch_jacobian2 and pytorch_jacobian do essentially the same computation, so the only runtime difference between them is overhead. We have an experimental API called AOTAutograd that can be used in this case: it traces out the computation and can pass it to a compiler (here, TorchScript). Since overhead is the dominating factor, applying this API speeds things up considerably. In the output below, returned_function is the compiled functorch_jacobian2 (average time per call, in µs):

```
pytorch_jacobian 964.2601013183594
functorch_jacobian 1162.0521545410156
functorch_jacobian2 921.9169616699219
returned_function 247.47848510742188
```
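The sum trick in observation 1 works because samples in the batch do not interact: the Jacobian of the batched model with respect to the whole batch is block diagonal, so differentiating the batch-summed output once recovers every per-sample block. A minimal sketch checking this on a small model, assuming PyTorch 2.x `torch.func` (functorch's transforms under their current name); it would not hold for layers that mix samples, such as BatchNorm in training mode.

```python
import torch
from torch import nn
from torch.func import jacrev, vmap

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 5))
points = torch.rand(4, 3)  # batch of 4 samples, input_dim 3

# 1. Per-sample Jacobians, shape (batch, output_dim, input_dim).
per_sample = vmap(jacrev(model))(points)

# 2. Sum trick: one Jacobian of the batch-summed output,
#    shape (output_dim, batch, input_dim), permuted to match (1).
summed = jacrev(lambda x: model(x).sum(dim=0))(points).permute(1, 0, 2)

# 3. Full Jacobian of the batched function, shape (batch, out, batch, in).
full = jacrev(model)(points)
idx = torch.arange(points.shape[0])
diag_blocks = full[idx, :, idx, :]  # the per-sample blocks, shape (batch, out, in)

# Zero out the diagonal blocks; what remains is sample-to-sample interaction.
mask = torch.eye(points.shape[0], dtype=torch.bool)[:, None, :, None]
off_diag = full.masked_fill(mask, 0.0)
```

The full Jacobian in step 3 is `batch` times larger than needed, which is why neither approach builds it in practice; it is computed here only to make the block structure visible.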
Code

```python
import time

import torch
from torch import nn
from functorch import vmap, jacrev
from functorch.compile import aot_function
from torch.autograd.functional import jacobian


def ts_compile(fx_g, inps):
    # compile the traced FX graph with TorchScript and freeze it
    print("compiling")
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def ts_compiler(f):
    # use the same TorchScript compiler for the forward and backward graphs
    return aot_function(f, ts_compile, ts_compile)


def functorch_jacobian(points):
    """calculate a jacobian tensor along a batch of inputs. returns something of size
    `batch_size` x `output_dim` x `input_dim`"""
    return vmap(jacrev(model))(points)


def functorch_jacobian2(points):
    """calculate a jacobian tensor along a batch of inputs. returns something of size
    `batch_size` x `output_dim` x `input_dim`"""
    def _func_sum(points):
        return model(points).sum(dim=0)
    return jacrev(_func_sum)(points).permute(1, 0, 2)


def pytorch_jacobian(points):
    """calculate a jacobian tensor along a batch of inputs. returns something of size
    `batch_size` x `output_dim` x `input_dim`"""
    def _func_sum(points):
        return model(points).sum(dim=0)
    return jacobian(_func_sum, points, create_graph=True, vectorize=True).permute(1, 0, 2)


torch.manual_seed(1234)

n_input, n_output, n_batch, n_hidden = 3, 5, 128, 64
for n_hidden in [64]:
    print(n_input, n_output, n_batch, n_hidden)
    model = nn.Sequential(nn.Linear(n_input, n_hidden), nn.Tanh(), nn.Linear(n_hidden, n_output)).eval()
    points = torch.rand((n_batch, n_input))
    k = pytorch_jacobian(points)
    v = functorch_jacobian(points)
    # needed due to the global variables changing, otherwise it caches the function
    # and doesn't recompute when the model changes.
    functorch_jacobian3 = ts_compiler(lambda points: functorch_jacobian2(points))
    m = functorch_jacobian3(points)
    assert torch.allclose(k, v)
    assert torch.allclose(k, m, rtol=1e-04)

    def bench(f, name):
        # `name` is unused; results are labeled with f.__name__
        iters = 5
        for _ in range(5):  # warm-up
            f(points)
        begin = time.time()
        for _ in range(iters):
            f(points)
        # average time per call, in microseconds
        print(f.__name__, (time.time() - begin) * 1e6 / iters)

    bench(pytorch_jacobian, name="pytorch")
    bench(functorch_jacobian, name="functorch1")
    bench(functorch_jacobian2, name="functorch2")
    bench(functorch_jacobian3, name="compiled functorch2")
    print()
```

One note about the code: it's not strictly doing the same thing right now, since aot_function only propagates gradients through its inputs and outputs, and here the model (and its parameters) is captured as a global rather than passed in. If you want to use aot_function (which is still a prototype feature!) in practice, you should make sure that everything you need gradients for, such as the model's parameters, is passed in as a function argument.

newalexander commented 2 years ago

@Chillee @zou3519 Thank you both for your quick and informative comments on this matter. I hadn't considered that functorch would benefit from repeating the sum trick, and I'll be sure to keep an eye on the aot_function progress. (We're hoping to use functorch as a base for a physics-informed neural network library analogous to deepxde, so having these easily-batched derivatives is great.)

zou3519 commented 2 years ago

> (We're hoping to use functorch as a base for a physics-informed neural network library analogous to deepxde, so having these easily-batched derivatives is great.)

If you run into any other problems with functorch, or have additional feature requests or feedback, please don't hesitate to open a new issue!

Chillee commented 2 years ago

@newalexander FWIW, I'm not totally sure the sum trick matters much for performance. When I increased the sizes, the gap narrowed.

Btw, for your use case, are you actually using CPU for computation, or was this just testing it out?

newalexander commented 2 years ago

@Chillee In practice, we want to do computations on GPU, this was just validation.

Apologies for the lengthy wall of code, but, for the curious, below is a somewhat minimal example showing how we can formulate and solve a simple PINN problem in functorch. We're not (currently) using the make_functional API, but being able to efficiently take derivatives of NN outputs with respect to inputs is key.

If possible, a very valuable additional feature would be lazy evaluation of derivatives. For example, in get_interior_loss below, the full Hessian of the network output is calculated, but only the diagonal entries are needed in the PDE. I believe this is something the deepxde PINN library does (e.g., https://github.com/lululxvi/deepxde/blob/master/deepxde/gradients.py#L6).

Code

```python
"""
we're solving Poisson's equation on a rectangular domain:

    D(x, y; u) = partial_xx u + partial_yy u + sin(pi x) sin(pi y),  (x, y) in Omega = [0, 1] x [0, 1]
    L(y; u) = u(0, y) = 0
    R(y; u) = u(1, y) = 0
    B(x; u) = u(x, 0) = 0
    T(x; u) = u(x, 1) = 0

let u: R^2 -> R be a neural network. we minimize the loss function

    sum_{(x, y) in Omega} | D(x, y; u) |^2
        + sum_{x in [0, 1]} (| B(x; u) |^2 + | T(x; u) |^2)
        + sum_{y in [0, 1]} (| L(y; u) |^2 + | R(y; u) |^2)

the learned solution can be compared to the known analytic solution:

    u(x, y) = sin(pi x) sin(pi y) / (2 pi^2)

compare to Julia implementation: https://neuralpde.sciml.ai/dev/pinn/poisson/
"""
import torch
from torch import nn, optim
from functorch import vmap, jacrev
from tqdm import tqdm


def sample_interior(n_points):
    """sample an N x 2 array from [0, 1] x [0, 1]"""
    points = torch.rand(n_points, 2)
    return points


def sample_x_boundary(n_points, x_val):
    """sample an N x 2 array from an x boundary"""
    points = x_val * torch.ones(n_points, 2)
    points[:, 1] = torch.rand(n_points)
    return points


def sample_y_boundary(n_points, y_val):
    """sample an N x 2 array from a y boundary"""
    points = y_val * torch.ones(n_points, 2)
    points[:, 0] = torch.rand(n_points)
    return points


def batch_hessian(model, points):
    """calculate a hessian matrix along a batch of inputs. returns something of size
    `batch_size` x `output_dim` x `input_dim` x `input_dim`"""
    return vmap(jacrev(jacrev(model)))(points)


def get_interior_loss(n_points, model, points=None, **kwargs):
    """loss in Omega = [0, 1] x [0, 1]"""
    if points is None:
        points = sample_interior(n_points)
    hessians = batch_hessian(model, points)
    # only the diagonal entries u_xx and u_yy of the Hessian enter the residual
    residual = hessians[:, 0, 0, 0] + hessians[:, 0, 1, 1] + torch.sin(torch.pi * points[:, 0]) * torch.sin(
        torch.pi * points[:, 1])
    return torch.mean(residual ** 2)


def get_left_loss(n_points, model, points=None, **kwargs):
    """loss in {0} x [0, 1]"""
    if points is None:
        points = sample_x_boundary(n_points, 0)
    u_hat = model(points)
    error = u_hat
    return torch.mean(error ** 2)


def get_right_loss(n_points, model, points=None, **kwargs):
    """loss in {1} x [0, 1]"""
    if points is None:
        points = sample_x_boundary(n_points, 1)
    u_hat = model(points)
    error = u_hat
    return torch.mean(error ** 2)


def get_bottom_loss(n_points, model, points=None, **kwargs):
    """loss in [0, 1] x {0}"""
    if points is None:
        points = sample_y_boundary(n_points, 0)
    u_hat = model(points)
    error = u_hat
    return torch.mean(error ** 2)


def get_top_loss(n_points, model, points=None, **kwargs):
    """loss in [0, 1] x {1}"""
    if points is None:
        points = sample_y_boundary(n_points, 1)
    u_hat = model(points)
    error = u_hat
    return torch.mean(error ** 2)


def get_total_loss(n_interior, n_left, n_right, n_bottom, n_top, model,
                   interior_points=None, left_points=None, right_points=None, bottom_points=None, top_points=None):
    """
    interior_loss + left_loss + right_loss + bottom_loss + top_loss
    note that all loss conditions are equally weighted
    """
    total_loss = 0.
    kwargs = {}
    for loss_func, n_pts, pts in zip([get_interior_loss, get_left_loss, get_right_loss, get_bottom_loss, get_top_loss],
                                     [n_interior, n_left, n_right, n_bottom, n_top],
                                     [interior_points, left_points, right_points, bottom_points, top_points]):
        total_loss += loss_func(n_pts, model, pts, **kwargs)
    return total_loss


def get_model(n_hidden, n_layer, n_input, n_output, activation=nn.Tanh):
    layers = [nn.Linear(n_input, n_hidden), activation()]
    for _ in range(n_layer):
        layers.append(nn.Linear(n_hidden, n_hidden))
        layers.append(activation())
    layers.append(nn.Linear(n_hidden, n_output))
    return nn.Sequential(*layers)


def train_epoch_adam(n_interior, n_left, n_right, n_bottom, n_top, model, opt):
    # forward + backward + optimize
    total_loss = get_total_loss(n_interior, n_left, n_right, n_bottom, n_top, model)
    opt.zero_grad()
    total_loss.backward()
    opt.step()
    return total_loss


def train():
    n_interior, n_boundary = 2540, 160
    torch.manual_seed(1234)
    n_input, n_output = 2, 1
    n_hidden, n_layer = 20, 3
    n_epochs = 4096
    model = get_model(n_hidden, n_layer, n_input, n_output)
    opt = optim.Adam(model.parameters())
    for _ in tqdm(range(n_epochs)):
        train_epoch_adam(n_interior, n_boundary, n_boundary, n_boundary, n_boundary, model, opt)
    return model
```
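On the request to evaluate only the Hessian entries the PDE needs: one way to get a single second derivative d²u/dx_i² without forming the full Hessian is one forward-over-reverse Hessian-vector product H e_i per requested dimension, keeping only entry i. The sketch below is an assumption-laden illustration, not a functorch feature: it uses PyTorch 2.x `torch.func.grad`/`jvp` (functorch's `grad`/`jvp` under their current names), a hypothetical helper `pure_second_derivatives`, and the sum trick from earlier in the thread to handle the whole batch in one call. For this 2-D Poisson problem both diagonal entries are needed, so the saving over `batch_hessian` is small; the pattern pays off when an operator needs only a few of many second derivatives.

```python
import torch
from torch import nn
from torch.func import grad, jacrev, jvp, vmap

torch.manual_seed(1234)
# Small stand-in for the PINN's u: R^2 -> R (same layer types as get_model).
model = nn.Sequential(nn.Linear(2, 20), nn.Tanh(), nn.Linear(20, 1))


def batch_output_sum(points):
    # Samples do not interact, so the gradient of this scalar w.r.t. `points`
    # stacks the per-sample gradients (the sum trick).
    return model(points).sum()


def pure_second_derivatives(points, dims):
    """Return {i: d^2u/dx_i^2 at every point}, one Hessian-vector product
    per requested input dimension i (hypothetical helper)."""
    out = {}
    for i in dims:
        tangent = torch.zeros_like(points)
        tangent[:, i] = 1.0  # e_i for every sample in the batch
        # forward-over-reverse: row b of `hvp` is H_b @ e_i
        _, hvp = jvp(grad(batch_output_sum), (points,), (tangent,))
        out[i] = hvp[:, i]  # keep only the diagonal entry H_b[i, i]
    return out


points = torch.rand(16, 2)
second = pure_second_derivatives(points, dims=(0, 1))
laplacian = second[0] + second[1]  # u_xx + u_yy at each point

# Reference: the full per-sample Hessian, computed as batch_hessian does.
hessians = vmap(jacrev(jacrev(model)))(points)  # (batch, 1, 2, 2)
reference = hessians[:, 0, 0, 0] + hessians[:, 0, 1, 1]
```

In get_interior_loss, the residual could then be built from `laplacian` plus the sine source term in place of indexing into `hessians`; because `laplacian` still depends on the model parameters, `total_loss.backward()` would continue to work.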