soulitzer opened this issue 1 year ago
We don't track views at all right? `._is_view()` will be `False` as well?
That sounds ok to me: you are under autograd so you can't do anything autograd related. And escaping objects into global state from torch_dispatch level is very shady anyways.
Curious what @ezyang thinks though!
Smuggling tensor out of the mode is very naughty and I don't really see how we can even make it work in principle.
The context is that smuggling a tensor out of a mode is how xformers currently implements selective checkpoint - it uses TorchDispatchModes to implement a cache under autograd.
The issue is that if you cache the output of some operation and then perform an in-place operation afterwards, that's like mutating a tensor saved for backward, except this time you'll get silently incorrect results because there's no version counter to protect you. We want to produce an error for this case.
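For contrast, outside of a mode the version counter already turns this pattern into a hard error. A minimal repro of the protection that is missing when the cache lives under autograd:

```python
import torch

# Standard autograd version-counter protection: exp() saves its
# output for backward, so mutating that output in-place is detected.
a = torch.tensor(1., requires_grad=True)
b = torch.exp(a)   # b is saved for exp's backward
b.mul_(2)          # in-place op bumps b's version counter

try:
    b.sum().backward()
    caught = False
except RuntimeError:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    caught = True
print(caught)  # → True
```

Under a `__torch_dispatch__` cache below autograd, this check never fires, which is exactly the silent-correctness hazard described above.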
~If we're okay with implementing selective checkpoint as part of checkpoint itself, we can avoid doing naughty things entirely. This might be a good thing for other reasons - I can see selective checkpoint working better with other checkpoint features like early stopping if implemented this way.~
~If we want to keep selective checkpoint separate,~ I don't see a good way around tensor smuggling. In that case the workaround for the in-place issue is for all in-place operations to check whether we're modifying something in the cache.
cc @fmassa
So, just to confirm, you want to have selective checkpoint check the VC on the smuggled tensors so that it can check if mutation occurred.
The root cause of this problem is that the VC mechanism should live "lower" than autograd, but it currently lives at autograd for efficiency reasons. Some possibilities:
- Provide a lower-level version counter service that is not tied to autograd.
- Use @bdhirsh's upcoming mode-at-any-dispatch-key stuff to get the xformers thing to live after autograd but before ADInplaceOrView.
- Refactor where version counting is done so that non-autograd mechanisms can hook into it (but this might be quite involved / involve slowing down regular code).
- Have another variant of version counting that's stored on storage (@kurtamohler's storage PyObject preservation will help), which would then always be stable no matter where the VC setup is done.
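The storage-level variant could look something like this toy sketch (plain-Python stand-ins, not real PyTorch types; `Storage` and `TensorStub` are made-up names):

```python
class Storage:
    """Toy stand-in: the version counter lives on the storage itself."""
    def __init__(self):
        self._version = 0

class TensorStub:
    """Toy tensor: views share their base's storage, hence its counter."""
    def __init__(self, storage=None):
        self.storage = storage if storage is not None else Storage()

    def view(self):
        # A view aliases the same storage, so it observes the same
        # counter no matter which dispatch layer performed the view.
        return TensorStub(self.storage)

    def mul_(self, other):
        # Any in-place op bumps the shared, storage-level counter.
        self.storage._version += 1
        return self

base = TensorStub()
alias = base.view()
alias.mul_(2)
assert base.storage._version == 1  # base observes the mutation
```

The point of putting the counter on the storage is that it stays consistent even if the view was created below autograd, where ADInplaceOrView never ran.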
> So, just to confirm, you want to have selective checkpoint check the VC on the smuggled tensors so that it can check if mutation occurred.
Yeah
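Concretely, "check the VC on the smuggled tensors" could be sketched like this (`SavedTensorCache` is a made-up illustration, not the xformers implementation):

```python
import torch

class SavedTensorCache:
    """Made-up sketch: snapshot each tensor's version counter at save
    time and refuse to return it if it was mutated in-place since."""
    def __init__(self):
        self._cache = {}

    def save(self, key, t):
        self._cache[key] = (t, t._version)

    def load(self, key):
        t, saved_version = self._cache[key]
        if t._version != saved_version:
            raise RuntimeError(f"cached tensor {key!r} was mutated in-place")
        return t

cache = SavedTensorCache()
x = torch.ones(3)
cache.save("x", x)
assert cache.load("x") is x  # untouched: fine
x.add_(1)                    # bumps x._version
try:
    cache.load("x")
except RuntimeError as e:
    print(e)
```

Note this only works if in-place ops actually bump the counter, which is exactly what does not happen when the in-place op is performed below autograd - hence the rest of this thread.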
Another solution (short-term): just propagate the version counter instead of passing 0 (see below)? This solution won't decouple the version counter mechanism entirely from autograd (incrementing the version counter on in-place is still handled by autograd, but maybe that is what we want, for efficiency purposes?), but it does seem to fix this particular issue.
```cpp
Tensor detach(const Tensor& self) {
  // NB: detach() is not the same thing as alias()! The main difference is that
  // detach does not allow metadata change while alias does.
  return Tensor(self.getIntrusivePtr()->shallow_copy_and_detach(
      // NB: The ADInplaceOrView logic will overwrite these with the
      // appropriate values if it runs; otherwise these are the values.
      /*version_counter=*/0, // update to be self.getIntrusivePtr()->version_counter()
      /*allow_tensor_metadata_change=*/false));
}
```
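To see what the ADInplaceOrView overwrite buys you in normal eager mode (where it does run), a small sketch of the shared counter:

```python
import torch

# In regular eager mode the ADInplaceOrView kernel runs, so the
# detached alias shares the base's version counter:
a = torch.ones(2)
d = a.detach()
a.add_(1)  # in-place on the base bumps the shared counter

# Both tensors observe the same version.
assert d._version == a._version
```

Under a `__torch_dispatch__` mode below autograd, the snippet above shows the overwrite never happens and the detached alias is created with a fresh counter of 0, losing this sharing.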
I'm cool with this, esp. if we can avoid the overwrite when it's not necessary
Since this has come up again with integrating subclasses within PT2 and fp8 integration, I think we want to have a full solution.
One proposal I would have would be a `mark_as_view(t, base, is_bw_differentiable=True, is_fw_differentiable=True, view_fn=None, creation_meta=CM.Default)` function. This is the equivalent of our `VariableTypeUtils.h` `as_view` function (it will always work in-place, with the right asserts to ensure there is no pre-existing autograd metadata on the input). It can be used to add autograd view metadata to any Tensor from Python code. It is important to keep in mind that for regular use of dispatch class/mode, the ADInplaceOrView layer that we provide does add the autograd metadata properly.
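A toy sketch of what such a function might enforce (plain-Python stand-ins; `mark_as_view` is proposed here, not an existing PyTorch API):

```python
class ToyTensor:
    """Plain-Python stand-in for a Tensor with minimal view metadata."""
    def __init__(self):
        self._base = None  # the base tensor, if this tensor is a view

def mark_as_view(t, base, is_bw_differentiable=True,
                 is_fw_differentiable=True, view_fn=None,
                 creation_meta="Default"):
    # Works in-place, with an assert that there is no
    # pre-existing autograd metadata on the input.
    assert t._base is None, "input already has autograd view metadata"
    t._base = base
    return t

base = ToyTensor()
view = mark_as_view(ToyTensor(), base)
assert view._base is base
```

The defaults mirror the signature in the proposal above; a real implementation would live in C++ next to `as_view`.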
There is a separate question about making the storages from these views alias properly. In this case, I think that we should use `.set_()` to make storage aliasing work when operating below autograd. This can be added to the `mark_as_view()` method above for convenience. Also cc @zou3519 for the custom op part of this discussion.
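For reference, `.set_()` already gives storage aliasing today; a minimal sketch:

```python
import torch

a = torch.arange(4.)
b = torch.empty(0)
b.set_(a)  # b now shares a's storage (and sizes/strides)
assert b.data_ptr() == a.data_ptr()

b[0] = 42.
assert a[0].item() == 42.  # mutation is visible through the alias
```

So folding a `.set_()` call into `mark_as_view()` would make the autograd view metadata and the storage aliasing consistent in one step.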
Another option might be to have users reenable ADInplaceOrView manually:
```python
with torch._C._UnexcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)):
    saved_b = out.view_as(out)
```
Note that _UnexcludeDispatchKeyGuard doesn't actually exist yet, but it should be trivial to write if we want it.
Edit:
I think exposing `as_view` to Python is more useful for the custom op use case where I'd like to register a new view op and want to define its ADInplaceOrView kernel, but if I already have an existing view op, it might be better to reuse the existing ADInplaceOrView kernel.

A downside of offering a context manager though is that it could lead to ADInplaceOrView still being active when you reach the backend if you dispatch into a non-view or non-in-place op inside the context manager, leading to silently worse performance. Perhaps we can offer the context manager as part of a wrapper function instead, e.g. `out = with_tracking(torch.view)(inp)`.
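A sketch of what that wrapper could look like, with a placeholder context manager standing in for the not-yet-existing `_UnexcludeDispatchKeyGuard` (`with_tracking` and `_unexclude_adinplaceorview` are hypothetical names):

```python
import contextlib
from functools import wraps

@contextlib.contextmanager
def _unexclude_adinplaceorview():
    # Placeholder: a real implementation would re-include the
    # ADInplaceOrView dispatch key for the duration of the block.
    yield

def with_tracking(op):
    """Run a single op with ADInplaceOrView re-enabled, so the guard's
    scope cannot leak past the one call."""
    @wraps(op)
    def wrapper(*args, **kwargs):
        with _unexclude_adinplaceorview():
            return op(*args, **kwargs)
    return wrapper

# Usage sketch: out = with_tracking(torch.ops.aten.view.default)(inp, shape)
print(with_tracking(abs)(-3))  # → 3; the wrapper itself is op-agnostic
```

Scoping the guard inside the wrapper is what addresses the silent-performance concern: the key can never still be active when an unrelated op dispatches.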
A question that applies to both proposals is what happens if I dispatch into torch dispatch for an existing view op, let's say `torch.ops.aten.view.default`, and our `__torch_dispatch__` function returned a tensor that already has view metadata - what should the existing ADInplaceOrView kernel do?
> There is a separate question about making the storages from these views alias properly.
Views performed in the Python key already do result in storages that are aliased, if this is what you mean
```python
a = torch.tensor(1.)
with SaveATensorMode():
    b = torch.sin(a)
assert b.data_ptr() == saved_b.data_ptr()
```
> Views performed in the Python key already do result in storages that are aliased, if this is what you mean
Well, did you check that `data_ptr()` value? It is 0 for every Tensor created with `make_wrapper_subclass` IIRC.
> A question that applies to both proposals is what happens if I dispatch into torch dispatch for an existing view op, let's say `torch.ops.aten.view.default`, and our `__torch_dispatch__` function returned a tensor that already has view metadata - what should the existing ADInplaceOrView kernel do?
Usually the ADInplaceOrView kernel is responsible for handling this, but since we're operating under autograd, the version counter information is not correctly propagated.
Previously this was probably intentional for inference mode, but it may be useful to support version counter propagation for `__torch_dispatch__` use cases to prevent common silent correctness issues. If this were a Tensor subclass, we could probably use `enable_reentrant_dispatch`, but that may not work for modes (can we fix this?).
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @Lezcano @Varal7 @Chillee @samdow