
views created in __torch_dispatch__ share storage but not version_counter #96319

Open soulitzer opened 1 year ago

soulitzer commented 1 year ago

Usually the ADInplaceOrView kernel is responsible for handling this, but since we're operating under autograd, the version counter information is not correctly propagated.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

saved_b = None

class SaveATensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        global saved_b
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)
        if func == torch.ops.aten.sin.default:
            # smuggle a view of the output out of the mode
            saved_b = out.view_as(out)
        return out

a = torch.tensor(1.)
with SaveATensorMode():
    b = torch.sin(a)

assert b.data_ptr() == saved_b.data_ptr()

old_b_version = b._version
old_saved_b_version = saved_b._version
b.mul_(2)

print(b._version > old_b_version)  # True
print(saved_b._version > old_saved_b_version)  # False

old_b_version = b._version
old_saved_b_version = saved_b._version
saved_b.mul_(2)

print(b._version > old_b_version)  # False
print(saved_b._version > old_saved_b_version)  # True

Previously this was probably intentional for inference mode, but it may be useful to support version counter propagation for __torch_dispatch__ use cases to prevent common silent correctness issues.

If this were a Tensor subclass, we could probably use enable_reentrant_dispatch, but that may not work for modes (can we fix this?).
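
For reference, the subclass version could look something like this (untested sketch: torch.overrides.enable_reentrant_dispatch restores the pre-dispatch TLS, so the view_as should go through the ADInplaceOrView kernel again):

import torch
from torch.overrides import enable_reentrant_dispatch

saved_b = None

class SaveATensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        global saved_b
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)
        if func == torch.ops.aten.sin.default:
            # Re-enter full dispatch so the view is set up with proper
            # version counter / view metadata.
            with enable_reentrant_dispatch():
                saved_b = out.view_as(out)
        return out

a = torch.tensor(1.).as_subclass(SaveATensor)
b = torch.sin(a)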

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @Lezcano @Varal7 @Chillee @samdow

albanD commented 1 year ago

We don't track views at all, right? ._is_view() will be False as well?

That sounds ok to me: you are under autograd, so you can't do anything autograd-related. And escaping objects into global state from the torch_dispatch level is very shady anyway.

Curious what @ezyang thinks though!

ezyang commented 1 year ago

Smuggling a tensor out of the mode is very naughty and I don't really see how we can even make it work in principle.

soulitzer commented 1 year ago

The context is that smuggling a tensor out of a mode is how xformers currently implements selective checkpoint - it uses a TorchDispatchMode to implement a cache under autograd.

The issue is that if you cache the output of some operation and then perform an in-place operation afterwards, that's like mutating a tensor saved for backward, except this time you get silently incorrect results because there's no version counter to protect you. We want to produce an error for this case.
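
For comparison, when the tensor is saved through the normal autograd machinery, the version counter turns this into a hard error (standard eager behavior):

import torch

a = torch.tensor(1., requires_grad=True)
b = a.exp()   # exp saves its output for the backward pass
b.mul_(2)     # in-place op bumps b's version counter
b.backward()  # RuntimeError: one of the variables needed for gradient
              # computation has been modified by an inplace operation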

~If we're okay with implementing selective checkpoint as part of checkpoint itself, we can avoid doing naughty things entirely. This might be a good thing for other reasons - I can see selective checkpoint working better with other checkpoint features like early stopping if implemented this way.~

~If we want to keep selective checkpoint separate,~ I don't see a good way around tensor smuggling. In that case, the workaround for the in-place issue is for all in-place operations to check whether we're modifying something in the cache.
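
A rough sketch of that check as its own mode (GuardCacheMode is a made-up name; writes are detected via the op schema; untested):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class GuardCacheMode(TorchDispatchMode):
    """Errors instead of silently mutating a tensor held in the cache."""
    def __init__(self, cache):
        super().__init__()
        self.cached_ptrs = {t.data_ptr() for t in cache}

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        # Does the schema mark any argument as written to (in-place / out=)?
        mutates = any(arg.alias_info is not None and arg.alias_info.is_write
                      for arg in func._schema.arguments)
        if mutates:
            # Simplification: only positional tensor args are checked here.
            for a in args:
                if isinstance(a, torch.Tensor) and a.data_ptr() in self.cached_ptrs:
                    raise RuntimeError(
                        f"{func} would mutate a tensor saved in the cache")
        return func(*args, **kwargs)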

cc @fmassa

ezyang commented 1 year ago

So, just to confirm, you want to have selective checkpoint check the VC on the smuggled tensors so that it can detect whether mutation occurred.

The root cause of this problem is that the VC mechanism should live "lower" than autograd, but it currently lives at autograd for efficiency reasons. A few possibilities:

- Provide a lower-level version counter service that is not tied to autograd.
- Use @bdhirsh's upcoming mode-at-any-dispatch-key stuff to get the xformers thing to live after autograd but before ADInplaceOrView.
- Refactor where version counting is done so that non-autograd mechanisms can hook into it (but this might be quite involved / involve slowing down regular code).
- Have another variant of version counting that's stored on the storage (@kurtamohler's storage PyObject preservation will help), which would then always be stable no matter where the VC setup is done (see the sketch below).
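
To illustrate the storage-level variant with a toy Python model (storage_versions, bump_version, and storage_version are made-up names; a real implementation would live in C++):

import torch

# Toy model: key a manual version counter by the storage, which the base and
# the smuggled view share even when their tensor-level VCs do not.
storage_versions = {}

def bump_version(t):
    key = t.untyped_storage().data_ptr()
    storage_versions[key] = storage_versions.get(key, 0) + 1

def storage_version(t):
    return storage_versions.get(t.untyped_storage().data_ptr(), 0)

a = torch.ones(2)
b = a.view(2)
bump_version(b)                  # whoever intercepts the mutation calls this
assert storage_version(a) == 1   # same storage => same counter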

soulitzer commented 1 year ago

So, just to confirm, you want to have selective checkpoint check the VC on the smuggled tensors so that it can detect whether mutation occurred.

Yeah

Another solution (short-term): just propagate the version counter instead of passing 0 (see below). This won't decouple the version counter mechanism entirely from autograd (incrementing the version counter on in-place is still handled by autograd, but maybe that is what we want, for efficiency purposes?), but it does seem to fix this particular issue.

Tensor detach(const Tensor& self) {
  // NB: detach() is not the same thing as alias()! The main difference is that
  // detach does not allow metadata change while alias does.
  return Tensor(self.getIntrusivePtr()->shallow_copy_and_detach(
    // NB: The ADInplaceOrView logic will overwrite these with the
    // appropriate values if it runs; otherwise these are the values.
    /*version_counter=*/0, // update to be self.getIntrusivePtr()->version_counter()
    /*allow_tensor_metadata_change=*/false));
}

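If I'm reading shallow_copy_and_detach right, VariableVersion is a shared handle, so propagating it here would make saved_b in the repro above share b's version counter, and both of the False prints would flip to True.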

ezyang commented 1 year ago

I'm cool with this, especially if we can avoid the overwrite when it's not necessary.

albanD commented 1 year ago

Since this has come up again with subclass integration in PT2 and with the fp8 integration, I think we want to have a full solution.

One proposal I would have would be:

It is important to keep in mind that for regular use of a dispatch class/mode, the ADInplaceOrView layer that we provide does add the autograd metadata properly.
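
For example, a quick check in plain eager (no mode active) shows both the view metadata and the shared version counter are set up:

import torch

a = torch.ones(2)
b = a.view_as(a)        # normal dispatch: goes through ADInplaceOrView
assert b._is_view()     # view metadata was recorded
v = b._version
a.mul_(2)               # mutating the base bumps the shared counter
assert b._version > v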

There is a separate question about making the storage from these views alias properly. In this case, I think that we should:

Also cc @zou3519 for the custom op part of this discussion.

soulitzer commented 1 year ago

Another option might be to have users re-enable ADInplaceOrView manually:

with torch._C._UnexcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)):
    saved_b = out.view_as(out)

Note that _UnexcludeDispatchKeyGuard doesn't actually exist yet, but it should be trivial to write if we want it.

Edit:

I think exposing as_view to Python is more useful for the custom op use case, where I'd like to register a new view op and want to define its ADInplaceOrView kernel; but if I already have an existing view op, it might be better to reuse the existing ADInplaceOrView kernel.

A downside of offering a context manager, though, is that ADInplaceOrView could still be active when you reach the backend if you dispatch into a non-view, non-in-place op inside the context manager, leading to silently worse performance. Perhaps we can offer the context manager as part of a wrapper function instead, e.g. out = with_tracking(torch.view)(inp).
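
A sketch of that wrapper, reusing the hypothetical _UnexcludeDispatchKeyGuard from above (with_tracking is a made-up name too):

import torch

def with_tracking(fn):
    # Hypothetical: re-enable ADInplaceOrView for exactly one call so the
    # guard cannot stay active for unrelated dispatches further down.
    def wrapper(*args, **kwargs):
        keys = torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
        with torch._C._UnexcludeDispatchKeyGuard(keys):
            return fn(*args, **kwargs)
    return wrapper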

A question that applies to both proposals: if I dispatch into torch dispatch for an existing view op, let's say torch.ops.aten.view.default, and our __torch_dispatch__ function returns a tensor that already has view metadata, what should the existing ADInplaceOrView kernel do?

soulitzer commented 1 year ago

There is a separate question about making the storage from these views alias properly.

Views performed in the Python key do already result in storages that are aliased, if this is what you mean:

a = torch.tensor(1.)
with SaveATensorMode():
    b = torch.sin(a)

assert b.data_ptr() == saved_b.data_ptr()

albanD commented 1 year ago

Views performed in the Python key do already result in storages that are aliased, if this is what you mean

Well, did you check that data_ptr() value? It is 0 for every Tensor created with make_wrapper_subclass, IIRC.

albanD commented 1 year ago

A question that applies to both proposals: if I dispatch into torch dispatch for an existing view op, let's say torch.ops.aten.view.default, and our __torch_dispatch__ function returns a tensor that already has view metadata, what should the existing ADInplaceOrView kernel do?

I think a good first draft for this API would be: