zou3519 opened this issue 1 year ago
Just make https://github.com/pytorch/pytorch/blob/e4ea751810bd1b27a105ac43ce2c8c84fabc1167/c10/core/TensorImpl.h#L1084 return false for BatchedTensor and this will go away! :)
Hmm, you mean GradWrapper, right?
How does this work? Is there special logic in forward-mode AD that handles support_as_strided?
It's not forward-AD specific. There's logic in ADInplaceOrView that checks the tensor's support_as_strided method, so this would apply to all views.
There is special logic in all of autograd for this :) It basically replaces all the places where we would usually call as_strided() with a call to the original view op. We use this to handle conjugate views, cross-dtype views (which can't be replaced with as_strided), and nested tensors (which can't handle a generic as_strided).
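For illustration, here is a minimal sketch (not the actual autograd internals) of the two replay strategies, on a plain strided tensor where both happen to work:

```python
import torch

base = torch.rand(4)
view = base.view(2, 2)

# Strategy 1: reconstruct the view generically with as_strided.
# This works for plain strided tensors, but not for conjugate
# views, cross-dtype views, or nested tensors.
replayed_strided = base.as_strided(
    view.size(), view.stride(), view.storage_offset()
)

# Strategy 2: replay the original view op (what autograd does when
# support_as_strided() returns false). This requires remembering
# which view op created the tensor.
replayed_view = base.view(2, 2)

assert torch.equal(replayed_strided, replayed_view)
```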
@albanD do you have a sense of how much overhead this adds?
Making this return false for BatchedTensor doesn't actually work because BatchedTensor isn't directly involved in autograd -- autograd sees the TensorWrapper / GradWrapper. As Jeffrey mentioned, we would have to set support_as_strided=False for GradWrapper, which would mean that even if vmap is not involved (e.g. the user just uses functorch.{jvp, grad}), they would take the performance hit.
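As a concrete example of that vmap-free path (a sketch using functorch's public API; the wrapper types themselves are internal):

```python
import torch
from functorch import jvp

# Plain jvp with no vmap anywhere: the primal/tangent pair still
# goes through functorch's grad wrapper, so disabling
# support_as_strided on that wrapper would slow down this path too.
f = lambda x: (x * x).sum()
primal = torch.rand(3)
tangent = torch.ones(3)
out, directional_deriv = jvp(f, (primal,), (tangent,))
print(out, directional_deriv)  # directional_deriv == (2 * primal).sum()
```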
The difference is noticeable for very small ops but not a dealbreaker either:
In [12]: a = torch.view_as_real(torch.rand(2, dtype=torch.complex64, requires_grad=True).clone())
In [13]: b = torch.rand(4, requires_grad=True).clone().view(2, 2)
In [14]: %timeit tmp = a.view_as(a)
969 ns ± 2.32 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [15]: %timeit tmp = b.view_as(b)
866 ns ± 12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [16]: %timeit tmp = a.add_(1)
3.27 µs ± 11 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [17]: %timeit tmp = b.add_(1)
2.91 µs ± 14.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
The first pair (view_as) shows the cost of tracking a full view op instead of nothing. The second pair (add_) shows the cost of replaying all the views instead of a single as_strided.
Thanks Alban. A few hundred nanoseconds is not that bad.
Had some more offline discussion with Alban. Here's the current plan on record:
1) First, we should see if we can easily prove that the as_strided is "safe" (see the sketch after this list). If we can, then no problem: we can write a batching rule for it.
2) If it is not easy to prove that the as_strided is "safe", then we may need to thread that information through the view system. I.e., when someone calls a view function (like tensor.diag(), and not as_strided() directly), we thread through the information that the view is a "safe as_strided". This is technically complicated, so we prefer solution 1 (or something else) if possible.
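For intuition only, here is a toy sketch of what "proving safety" could mean: the view must not address memory outside what its base can address. The predicate name and the exact criterion are hypothetical; a real batching rule would also need to handle negative strides, memory overlap, and the batch dimension.

```python
import torch

def _max_extent(size, stride, offset):
    """Largest linear storage index reachable by an as_strided view
    with these parameters (assumes non-negative strides)."""
    hi = offset
    for s, st in zip(size, stride):
        if s > 0:
            hi += (s - 1) * st
    return hi

def is_safe_as_strided(base, size, stride, storage_offset=None):
    # Hypothetical predicate: the view is "safe" if every element it
    # addresses lies within the region that `base` itself addresses.
    if storage_offset is None:
        storage_offset = base.storage_offset()
    lo = storage_offset
    hi = _max_extent(size, stride, storage_offset)
    base_lo = base.storage_offset()
    base_hi = _max_extent(base.size(), base.stride(), base.storage_offset())
    return base_lo <= lo and hi <= base_hi

x = torch.rand(4, 4)
print(is_safe_as_strided(x, (2, 2), x.stride()))  # True: stays inside x
print(is_safe_as_strided(x, (4, 4), (8, 1)))      # False: reads past x
```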
The Problem
The discussion above raises the following:
If I am understanding what is going on correctly, the root cause of the problem is that, ignoring vmap for a second, in x.copy_(y), where x is a regular Tensor and y is a dual tensor: view.tangent gets assigned x.tangent.as_strided(something). Now, if y.tangent is a BatchedTensor, then calling as_strided on it may raise the above error message.
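To make the mechanism concrete, here is a minimal sketch with plain forward-mode AD and no vmap, using torch.autograd.forward_ad; the point is only that an in-place copy of a dual tensor into a view forces autograd to relate the view's tangent to the base's tangent:

```python
import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    base = torch.zeros(4)
    x = base.view(2, 2)  # x is a view of base
    y = fwAD.make_dual(torch.rand(2, 2), torch.ones(2, 2))
    x.copy_(y)  # regular view x <- dual tensor y

    # base picked up a tangent through the in-place copy on its
    # view; reconstructing a view's tangent from the base's tangent
    # is where as_strided (or view replay) enters the picture.
    _, tangent = fwAD.unpack_dual(base)
    print(tangent)  # expected: a tensor of ones with shape (4,)
```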
Is this actually a problem?
Previously, our approach was to say that the vmap x jvp composition only works when the user vmaps over dimension 0. However, that's not quite correct -- if the user uses non-contiguous tensors, then they'll run into this problem. Also, vmap x jvp can produce tensors where the batch dimension is not at 0, so the user has no control over this.
Potential solutions
1) Instead of calling as_strided, we should call the original view operation. This means we should save the original view operation somewhere.
2) (a) Introduce a safe_as_strided operator, (b) save some metadata on whether a view Tensor was created from its base through a chain of "safe" operations or not, and (c) dispatch to either safe_as_strided or as_strided based on that metadata. A sketch of this dispatch is below.
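A toy sketch of option 2); every name here (ViewMeta, replay_view, safe_as_strided) is hypothetical, and the real work would live in the C++ dispatcher rather than Python:

```python
import torch

class ViewMeta:
    """Hypothetical per-view metadata: was this view created from
    its base through a chain of "safe" operations?"""
    def __init__(self, created_safely: bool):
        self.created_safely = created_safely

def safe_as_strided(base, size, stride, offset):
    # Stand-in for a hypothetical operator that vmap would know how
    # to batch; here it just forwards to as_strided.
    return base.as_strided(size, stride, offset)

def replay_view(base, meta, size, stride, offset):
    # (c): dispatch based on the recorded metadata.
    if meta.created_safely:
        return safe_as_strided(base, size, stride, offset)
    return base.as_strided(size, stride, offset)

x = torch.rand(4)
v = replay_view(x, ViewMeta(created_safely=True), (2, 2), (2, 1), 0)
```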
Thoughts? cc @soulitzer @albanD