pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

vmap and forward-mode AD fail sometimes on in-place views #999

Open zou3519 opened 1 year ago

zou3519 commented 1 year ago

The Problem

import torch
from functorch import jvp, vmap
from functools import partial

B = 2

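def f(x, y):
    x = x.clone()
    view = x[0]    # view of x taken before the in-place write
    x.copy_(y)     # x is a regular Tensor, y is a dual tensor
    return view, x

def push_jvp(x, y, yt):
    # forward-mode AD of f with respect to y, with tangent yt
    return jvp(partial(f, x), (y,), (yt,))

# no input is vmapped over dim 0: x uses dim 1, y and yt use dim 2
x = torch.randn(2, B, 6)
y = torch.randn(2, 6, B)
yt = torch.randn(2, 6, B)
outs, tangents = vmap(push_jvp, in_dims=(1, 2, 2))(x, y, yt)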

raises the following:

RuntimeError: vmap: Calling Tensor.as_strided is not supported unless the batch dims being vmapped over are at the front of the tensor (in memory layout). When they are not at the front of the tensor this operation can be error prone so we actively discourage it; please file us a bug report and/or try to express the as_strided operation in terms of PyTorch view operations

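If I am understanding what is going on correctly, the root cause of the problem is that, ignoring vmap for a second, in x.copy_(y), x is a regular Tensor and y is a dual tensor. The copy_ gives x a tangent that comes from y, and because view was created as a view of x, its tangent is then derived from that tangent with as_strided.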

Now, if y.tangent is a BatchedTensor, then calling as_strided on it may raise the above error message.

Is this actually a problem?

Previously, our approach was to say that vmap x jvp composition only works when the user vmaps over dimension 0. However, that's not quite correct: if the user uses non-contiguous tensors, they'll run into this problem. Also, vmap x jvp can itself produce tensors where the batch dimension is not at 0, so the user has no control over this.
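To make "at the front of the tensor (in memory layout)" concrete, here is a minimal pure-Python sketch applied to the shapes in the repro. It assumes a batch dim counts as "at the front" when its stride is at least as large as every other dim's stride; that is a simplification for illustration, not functorch's actual check, and both helper names are hypothetical.

```python
def contiguous_strides(sizes):
    """Row-major (C-contiguous) strides for a tensor of the given sizes."""
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def bdim_at_front_in_layout(strides, bdim):
    """Simplified model (hypothetical helper): the batch dim is outermost in
    memory if no other dim has a larger stride."""
    return all(strides[bdim] >= s for d, s in enumerate(strides) if d != bdim)


B = 2
# Shapes from the repro: x is vmapped over dim 1, y and yt over dim 2.
x_strides = contiguous_strides((2, B, 6))       # (12, 6, 1)
y_strides = contiguous_strides((2, 6, B))       # (12, 2, 1)
print(bdim_at_front_in_layout(x_strides, 1))    # False
print(bdim_at_front_in_layout(y_strides, 2))    # False

# Even with in_dims=0, a non-contiguous input can hide its batch dim "inside":
# transposing a contiguous (6, B) tensor gives shape (B, 6) with strides (1, B).
t_strides = contiguous_strides((6, B))[::-1]    # (1, 2)
print(bdim_at_front_in_layout(t_strides, 0))    # False
```

By contrast, a contiguous input of shape (B, 2, 6) vmapped over dim 0 has strides (12, 6, 1) and passes this check.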

Potential solutions

  1. When a tangent gets propagated to views as a result of an in-place operation, instead of calling as_strided, we should call the original view operation. This means we should save the original view operation somewhere.
  2. (From Jeffrey) An alternative to (1) is: instead of calling as_strided, figure out what the correct non-as_strided view operation(s) are by reading the sizes/strides/storage_offset, and call those instead.
  3. It is possible to write a batching rule for a "safe as_strided". An as_strided call is safe if it does not expose memory that was not previously exposed in the Tensor. We would (a) add a safe_as_strided operator, (b) save some metadata on whether a view Tensor was created from a base through a chain of "safe" operations, and (c) dispatch to either safe_as_strided or as_strided accordingly.
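The "safe" condition in (3) can be stated as a brute-force set check: an as_strided is safe if every storage offset the new view can reach was already reachable from the input. A minimal sketch, assuming the input's exposed memory is exactly what its own (sizes, strides, offset) can address; the helper names are hypothetical, and since it enumerates every element it costs O(numel), so a real safe_as_strided rule would presumably reason about bounds instead.

```python
def reachable_offsets(sizes, strides, offset):
    """Every storage offset that a (sizes, strides, offset) view can address."""
    offsets = {offset}
    for size, stride in zip(sizes, strides):
        offsets = {o + i * stride for o in offsets for i in range(size)}
    return offsets


def is_safe_as_strided(src, new):
    """True if as_strided(*new) on a tensor with geometry src exposes no
    storage that src could not already see. Both are (sizes, strides, offset)."""
    return reachable_offsets(*new) <= reachable_offsets(*src)


base = ((2, 3), (3, 1), 0)    # contiguous 2x3 tensor: offsets 0..5
row1 = ((3,), (1,), 3)        # base[1]: offsets 3, 4, 5

print(is_safe_as_strided(base, row1))               # True
print(is_safe_as_strided(base, ((3,), (1,), 4)))    # False: touches offset 6
print(is_safe_as_strided(row1, base))               # False: re-exposes offsets 0..2
```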

Thoughts? cc @soulitzer @albanD

albanD commented 1 year ago

Just make https://github.com/pytorch/pytorch/blob/e4ea751810bd1b27a105ac43ce2c8c84fabc1167/c10/core/TensorImpl.h#L1084 return false for BatchedTensor and this will go away! :)

soulitzer commented 1 year ago

Hmm, you mean GradWrapper, right?

zou3519 commented 1 year ago

How does this work? Is there special logic in forward-mode AD that handles support_as_strided?

soulitzer commented 1 year ago

It's not specific to forward AD. There's logic in ADInplaceOrView that checks the tensor's support_as_strided method, so this would apply to all views.

albanD commented 1 year ago

There is special logic in all of autograd for this :) It basically replaces every place where we would usually call as_strided() with a call to the original view op.

We use this to handle conjugate views, cross-dtype views (which can't be replaced with as_strided), and nested tensors (which can't handle a generic as_strided).
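One general reason replaying the original view op is more robust: as_strided bakes in the strides and offset computed against the original base, so applying that saved geometry to a tensor with a different memory layout reads the wrong elements, while replaying the view op recomputes the geometry from whatever layout it is given. The toy model below illustrates this; Strided, toy_as_strided, and toy_select are hypothetical stand-ins, not PyTorch APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Strided:
    """Toy strided tensor: flat storage plus sizes, strides and offset."""
    storage: tuple
    sizes: tuple
    strides: tuple
    offset: int = 0

    def read(self):
        """Materialize the logical elements as nested lists."""
        def go(sizes, strides, offset):
            if not sizes:
                return self.storage[offset]
            return [go(sizes[1:], strides[1:], offset + i * strides[0])
                    for i in range(sizes[0])]
        return go(self.sizes, self.strides, self.offset)


def toy_as_strided(t, sizes, strides, offset):
    """Reinterpret t's storage with a fixed geometry (layout-dependent)."""
    return Strided(t.storage, tuple(sizes), tuple(strides), offset)


def toy_select(t, dim, index):
    """Like t.select(dim, index): geometry derived from t's own strides."""
    return Strided(t.storage,
                   t.sizes[:dim] + t.sizes[dim + 1:],
                   t.strides[:dim] + t.strides[dim + 1:],
                   t.offset + index * t.strides[dim])


# Base: logical [[0, 1, 2], [3, 4, 5]] stored row-major.
base = Strided((0, 1, 2, 3, 4, 5), (2, 3), (3, 1))
view = toy_select(base, 0, 0)                         # like x[0]
saved_geometry = (view.sizes, view.strides, view.offset)

# Same logical shape, column-major layout: [[10, 11, 12], [13, 14, 15]].
other = Strided((10, 13, 11, 14, 12, 15), (2, 3), (1, 2))

print(toy_as_strided(other, *saved_geometry).read())  # [10, 13, 11]  (not row 0)
print(toy_select(other, 0, 0).read())                 # [10, 11, 12]  (row 0)
```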

zou3519 commented 1 year ago

@albanD do you have a sense of how much overhead this adds?

Making this return false for BatchedTensor doesn't actually work because BatchedTensor isn't directly involved in autograd; autograd sees the TensorWrapper / GradWrapper. As Jeffrey mentioned, we would have to set support_as_strided=False for GradWrapper, which would mean that even if vmap is not involved (e.g. the user just uses functorch.{jvp, grad}), they would take the performance hit.

albanD commented 1 year ago

The difference is noticeable for very small ops but not a dealbreaker either:

In [12]: a = torch.view_as_real(torch.rand(2, dtype=torch.complex64, requires_grad=True).clone())

In [13]: b = torch.rand(4, requires_grad=True).clone().view(2, 2)

In [14]: %timeit tmp = a.view_as(a)
969 ns ± 2.32 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [15]: %timeit tmp = b.view_as(b)
866 ns ± 12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [16]: %timeit tmp = a.add_(1)
3.27 µs ± 11 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [17]: %timeit tmp = b.add_(1)
2.91 µs ± 14.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

The first pair is the cost of tracking a full view op instead of nothing (969 ns vs. 866 ns, about 100 ns). The second pair is the cost of replaying all the views instead of a single as_strided (3.27 µs vs. 2.91 µs, about 360 ns).

zou3519 commented 1 year ago

Thanks Alban. A few hundred nanoseconds is not that bad

zou3519 commented 1 year ago

Had some more offline discussion with Alban.

It's important to note that:

So, here's the current plan on record:

1) First, we should see if we can easily prove that the as_strided is "safe". If we can, then there's no problem: we can write a batching rule for it.
2) If it is not easy to prove that the as_strided is "safe", then we may need to thread that information through the view system. That is, when someone calls a view function (like tensor.diag(), rather than as_strided() directly), we record that the resulting view is a "safe" as_strided. This is technically complicated, so we prefer solution no. 1 (or something else) if possible.