@newalexander I'm a bit confused. The direct analogue to `torch.autograd.functional.jacobian` should be `functorch.jacrev`, but in the above script it looks like we are comparing `vmap(jacrev)` with `torch.autograd.functional.jacobian`.
Edit: I am wrong, I see the difference now.
At any rate, we are interested in investigating and closing any performance gaps we find between functorch and `torch.autograd.functional`.
To answer your question, the two approaches are morally identical: `torch.autograd.functional.jacobian` actually uses an earlier version of vmap under the hood when you set `vectorize=True`. It's possible there is slightly more overhead in functorch's vmap; I'd have to actually run the model to know, but since the difference is fairly small (about 5% here), I wouldn't be shocked if it came down to overhead. Also, since it looks like you're running on CPU, overheads tend not to be hidden, which might explain the gap.
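Concretely, the two call patterns being compared look roughly like the sketch below (the model, shapes, and variable names are illustrative, not the ones from the original script):

```python
import torch
from functorch import vmap, jacrev

# Illustrative model and batch (not the ones from the issue).
model = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))
x = torch.randn(128, 3)

# functorch: per-sample jacobian, batched explicitly with vmap -> shape (128, 1, 3)
functorch_jac = vmap(jacrev(model))(x)

# torch.autograd.functional: one jacobian of a batch-summed output, vectorized
# internally (the "sum trick" discussed below) -> shape (1, 128, 3)
pytorch_jac = torch.autograd.functional.jacobian(
    lambda inp: model(inp).sum(dim=0), x, vectorize=True
)

# The two agree up to the ordering of the batch and output dimensions.
assert torch.allclose(functorch_jac.permute(1, 0, 2), pytorch_jac, atol=1e-6)
```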
As for improving performance, there are a lot of possibilities. It's possible that we have suboptimal batching rules somewhere (or batching rules with more overhead than needed). We also have some prototype compilation-ish features that could work for this situation.
@newalexander I investigated it, and I have 2 main observations.

First, there's a difference between `functorch_jacobian` and `pytorch_jacobian`: in the first case you are actually computing the batched jacobian, while in the second case you are computing a single jacobian of a function that sums over the batch at the end. Since the samples are independent, the two give identical results, and the second option turns out to be a bit more efficient. But... we can do the same thing with functorch, and it appears to be a bit faster than `autograd.functional`. Here, `functorch_jacobian2` is the variant implemented with the sum trick.

pytorch_jacobian 964.2601013183594
functorch_jacobian 1162.0521545410156
functorch_jacobian2 921.9169616699219
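For reference, here is a rough sketch of what the sum-trick variant could look like in functorch (illustrative, not the exact code from the benchmark):

```python
from functorch import jacrev

def functorch_jacobian2(model, x):
    # Sum over the batch dimension first. Because each output row depends only on
    # its own input row, the jacobian of the summed output contains exactly the
    # per-sample jacobians, laid out as (out_dim, batch, in_dim).
    def summed_model(inp):
        return model(inp).sum(dim=0)
    return jacrev(summed_model)(x)
```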
Second, `functorch_jacobian2` and `pytorch_jacobian` are pretty much doing the same computation, and the only runtime difference between them is overhead. We have an experimental API called AOTAutograd that can be used in this case: it traces out the computation and can pass it to a compiler (here, TorchScript). Since overhead is the dominating factor, applying this API ends up speeding things up by a lot.

pytorch_jacobian 964.2601013183594
functorch_jacobian 1162.0521545410156
functorch_jacobian2 921.9169616699219
returned_function 247.47848510742188
One note about the code: it's not strictly doing the same thing right now, since `aot_function` only propagates gradients through the inputs and the outputs. If you want to use `aot_function` (which is still a prototype feature!) in practice, you should make sure that everything you need gradients for is passed in as a function argument.
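As a rough illustration of that workflow (the import paths and compiler choices follow the functorch prototype of the time and may have changed; the model and shapes are the illustrative ones from the earlier sketch):

```python
import torch
from functorch import jacrev
from functorch.compile import aot_function, ts_compile  # prototype API; location may differ by version

model = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))
x = torch.randn(128, 3)

def jacobian_fn(inp):
    # Sum-trick jacobian, as in functorch_jacobian2 above.
    return jacrev(lambda t: model(t).sum(dim=0))(inp)

# Trace the computation once and hand the forward/backward graphs to TorchScript.
returned_function = aot_function(jacobian_fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
jac = returned_function(x)

# Caveat from the note above: gradients only propagate through explicit inputs and
# outputs, so anything you need gradients for (e.g. the parameters, exposed via
# make_functional) should be passed in as arguments rather than closed over.
```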
@Chillee @zou3519 Thank you both for your quick and informative comments on this matter. I hadn't considered that functorch would benefit from repeating the sum
trick, and I'll be sure to keep an eye on the aot_function
progress. (We're hoping to use functorch
as a base for a physics-informed neural network library analogous to deepxde
, so having these easily-batched derivatives is great.)
If you run into any other problems with functorch, or have additional feature requests or feedback, please don't hesitate to open a new issue!
@newalexander Fwiw, I'm not totally sure that the sum trick really matters much for performance. When I increased the problem size, the gap narrowed.
Btw, for your use case, are you actually using CPU for computation, or was this just testing it out?
@Chillee In practice we want to do computations on GPU; this was just validation.
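For GPU validation, note that kernel launches are asynchronous, so timings need CUDA synchronization to be meaningful; a small sketch (same illustrative model as above) using torch.utils.benchmark, which handles the synchronization:

```python
import torch
from functorch import vmap, jacrev
from torch.utils.benchmark import Timer

device = torch.device("cuda")
model = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1)).to(device)
x = torch.randn(128, 3, device=device)

# torch.utils.benchmark.Timer synchronizes CUDA around each measurement, so the
# reported times include the actual kernel execution, not just the launches.
timer = Timer(
    stmt="vmap(jacrev(model))(x)",
    globals={"vmap": vmap, "jacrev": jacrev, "model": model, "x": x},
)
print(timer.timeit(100))
```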
Apologies for the lengthy wall of code, but, for the curious, below is a somewhat minimal example showing how we can formulate and solve a simple PINN problem in functorch. We're not (currently) using the `make_functional` API, but being able to efficiently take derivatives of NN outputs with respect to inputs is key.
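As a schematic illustration of that pattern (the network, the `pde_residual` helper, and the shapes below are assumptions for the sketch, not the code from this comment), the batched input-derivatives for an interior loss can be set up like this:

```python
import torch
from functorch import vmap, jacrev, jacfwd

# Assumed network: maps a spatial point of dimension d to a scalar solution value.
d = 2
u_net = torch.nn.Sequential(torch.nn.Linear(d, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))

def u(x):                                # x: (d,) -> scalar
    return u_net(x).squeeze(-1)

# Batched derivatives of u with respect to its inputs.
grad_u = vmap(jacrev(u))                 # (n, d) -> (n, d), e.g. for advection or Neumann terms
hess_u = vmap(jacfwd(jacrev(u)))         # (n, d) -> (n, d, d)

def pde_residual(x):
    # Example interior residual for a Poisson-style PDE, laplacian(u) - f = 0,
    # with f chosen arbitrarily here for illustration.
    laplacian = hess_u(x).diagonal(dim1=-2, dim2=-1).sum(-1)   # (n,)
    f = torch.ones(x.shape[0])
    return laplacian - f

x_interior = torch.rand(256, d)
interior_loss = pde_residual(x_interior).pow(2).mean()
```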
If possible, a highly valued additional feature would be lazy evaluation of derivatives. For example, in `get_interior_loss` below, the full Hessian of the network output is calculated, but only the diagonal entries are needed in the PDE. This is something that, I believe, the deepxde PINN library does (e.g., https://github.com/lululxvi/deepxde/blob/master/deepxde/gradients.py#L6).
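In the meantime, one way to get only the diagonal second derivatives (rather than materializing the full batched Hessian) is to take Hessian-vector products against basis vectors; a rough sketch, with the same assumed `u_net` as in the previous sketch:

```python
import torch
from functorch import vmap, grad, jvp

d = 2
u_net = torch.nn.Sequential(torch.nn.Linear(d, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))

def u(x):                                   # x: (d,) -> scalar
    return u_net(x).squeeze(-1)

def hessian_diag(x):
    # Only the d diagonal entries H_ii: jvp(grad(u)) gives the Hessian-vector
    # product H @ e_i, and indexing [i] keeps just the i-th diagonal entry.
    basis = torch.eye(x.shape[0])
    entries = [jvp(grad(u), (x,), (basis[i],))[1][i] for i in range(x.shape[0])]
    return torch.stack(entries)

x_interior = torch.rand(256, d)
laplacian = vmap(hessian_diag)(x_interior).sum(-1)   # (256,), no (256, d, d) Hessian stored
```

This still costs one derivative pass per input dimension, so it is not lazy evaluation in the deepxde sense, but it avoids building and storing the full batched Hessian.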
We're computing batchwise jacobians of a network with respect to its inputs.
So the functorch jacobian is a bit slower to compute than the base pytorch jacobian. Is this to be expected, should I consider the difference to be fairly inconsequential, or am I conducting this comparison in an incorrect way? If it's to be expected, are there any optimizations of `vmap(jacrev)` planned?
Very nice work on the library, by the way.