pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/

`functionalize()` doesn't properly handle aliased program inputs #950

Open bdhirsh opened 2 years ago

bdhirsh commented 2 years ago

One reason that functionalization can't be written as a pure graph transform is that its output can depend on the input metadata - specifically whether or not the program inputs alias each other (so you could probably write it as a graph pass, as long as the graph was annotated with aliasing info about the program inputs).

That case actually isn't properly handled by functionalize() today, but it should be.

Take this example program:

def f(a, b):
    a.add_(1)
    c = b + 1
    return c

If you run functionalization with sample inputs that do not alias each other, then you'd expect to get out a function like below:

def f(a, b):
    a_updated = a.add(1)
    c = b.add(1)
    a.copy_(a_updated)  # functionalization needs to copy the mutation back to the input 
    return c
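For concreteness, here is roughly how you could reproduce the non-aliased case (a sketch; depending on your functorch/PyTorch version, functionalize and make_fx may instead live under torch.func and torch.fx.experimental.proxy_tensor):

import torch
from functorch import functionalize, make_fx

def f(a, b):
    a.add_(1)
    c = b + 1
    return c

# Two sample inputs that do not alias: the printed graph should contain the
# add / copy_ pattern sketched above (exact node names may differ).
print(make_fx(functionalize(f))(torch.ones(2), torch.ones(2)).code)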

If the inputs are aliased though, then functionalization should do something different. Below are two examples (the second one is harder to handle than the first):

# Example 1: one input is the "base" of another
x = torch.ones(2)
print(make_fx(functionalize(f))(x, x.view(-1)).code)

# should print something like...
def f(a, b):
    a_updated = a.add(1)
    # we need to regenerate b from its base (which is "a" in this case)
    b_updated = a_updated.view(-1)
    c = b_updated.add(1)
    a.copy_(a_updated)
    return c
# Example 2: both inputs are slices into some other "base"
x = torch.ones(4)
print(make_fx(functionalize(f))(x[:3], x[1:]).code)

# we need to regenerate "b" during the program, which requires:
# (1) propagating the mutation to the base
# (2) regenerating "b" from the newly updated base
# To do all of that, we need access to the base!
# That means that the shared base of a and b should be another input into the traced function
def f(a, b, base_a_and_b):
    a_updated = a.add(1)

    # scatter the updated slice back into the base
    base_a_and_b_updated = torch.ops.aten.slice_scatter(base_a_and_b, a_updated, 0, 0, 3)
    b_updated = base_a_and_b_updated[1:]
    c = b_updated.add(1)
    a.copy_(a_updated)
    return c

Notice that in the second example, we need to add a new input to the traced graph that didn't show up in the original program! In some cases, in order to correctly remove aliasing, the program needs access to the base tensor of the inputs - which isn't a direct input into the program, but can be obtained indirectly through tensor._base.

Implementation?

Problem 1: One issue is that we don't have an easy way, given a bunch of tensors that are known to alias, to find their "base" tensor.

The easiest way would probably be to use the _base attribute on the tensor, which autograd conveniently tracks for us. As @ezyang pointed out though, because this info is tracked by autograd, none of it is tracked if you're running your program under inference_mode().

In a lot of cases though, we should be able to "elect" a base. For example, given:

a = x[1:8]
b = x[2:6]
c = x[4:8]
functionalize(foo)(a, b, c)

We should be able to detect inside of functionalize() that "a" can be effectively used as a base of "b" and "c" (I'm not sure exactly what the implementation would look like, but a's sizes/strides/storage offset should tell us that its memory fully covers b's + c's memory).
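To illustrate, a rough sketch (not the actual functionalize() implementation, and only handling contiguous 1-D views) of the kind of check that could elect "a" as a base:

import torch

def covers(a, b):
    # True if a's memory fully contains b's memory: same storage, and b's
    # element range falls inside a's. A real check would also have to handle
    # strides and non-contiguous views.
    if a.storage().data_ptr() != b.storage().data_ptr():
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    a_start, a_end = a.storage_offset(), a.storage_offset() + a.numel()
    b_start, b_end = b.storage_offset(), b.storage_offset() + b.numel()
    return a_start <= b_start and b_end <= a_end

x = torch.ones(8)
a, b, c = x[1:8], x[2:6], x[4:8]
print(covers(a, b), covers(a, c))  # True True -> "a" can serve as the base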

That still doesn't help us in all situations though, like in example 2 above. Or, if we run functionalize(foo)(b, c) with views given above.

I'm pretty sure that we're forced to rely on the _base attribute in those situations. We could either: (1) Detect that situation, and error if inference_mode() is turned on. (2) Keep inference mode turned off when inside of a torchdynamo-enabled region of the program. inference_mode() is an eager-mode concept, and any overhead that autograd-tracking creates will be traced away when you're running with torchdynamo, so this could be reasonable.
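As a sketch of option (1) (hypothetical helper name, just to illustrate the check):

import torch

def base_or_error(t):
    # ._base is populated by autograd's view tracking, so it isn't available
    # for views created under inference_mode(); detect that and error loudly
    # rather than silently mis-handling aliased inputs.
    if torch.is_inference_mode_enabled() or t.is_inference():
        raise RuntimeError(
            "functionalize() needs autograd view metadata (._base) to handle "
            "aliased inputs, which isn't tracked under inference_mode()")
    return t if t._base is None else t._base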

Problem 2: Tensor subclasses

In order to even detect the situation described in this issue, we need to know that two program inputs actually have aliased storages. This can be a problem though, if the inputs to functionalization are ("has-a") tensor subclasses instead of ordinary torch.Tensor objects.

The most important example of this is probably ProxyTensor (and FakeTensor?), since that's what we're using for tracing. The "dumb" way to detect aliasing of two proxy tensors is:

from torch.multiprocessing.reductions import StorageWeakRef

def are_aliases(a: ProxyTensor, b: ProxyTensor):
    # I think this would technically break because `.storage()` in python raises an error if `a.elem` is a meta tensor,
    # but we could probably arrange for a workaround
    return StorageWeakRef(a.elem.storage()) == StorageWeakRef(b.elem.storage())

# ...or, alternatively, with _base (note: `is`, not `==`, since we want identity rather than elementwise comparison):
def are_aliases_via_base(a: ProxyTensor, b: ProxyTensor):
    return ((a.elem._base is not None and a.elem._base is b.elem)
            or (b.elem._base is not None and b.elem._base is a.elem))

This might be an ok workaround for dynamo tracing, but we'd still end up with silent correctness issues if you're using another "has-a" tensor subclass.

I'm not sure what the best solution to this is (although the problem is pretty niche, since functionalization + custom has-a tensor subclass + aliased program inputs together hopefully aren't too common).

We could require subclasses to implement a def aliases(other: Tensor) -> bool method in order to properly work with functionalization?
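A sketch of what that protocol might look like, using a toy stand-in for a "has-a" subclass (the aliases() method is hypothetical, not an existing PyTorch API):

import torch

class WrapperTensor:
    # Toy stand-in for a "has-a" tensor subclass that holds an inner tensor in
    # `elem` (a real subclass would go through torch.Tensor._make_wrapper_subclass).
    def __init__(self, elem):
        self.elem = elem

    # Hypothetical protocol: functionalization asks the subclass whether two
    # wrappers alias, instead of trying to peek at storages it can't see through.
    def aliases(self, other) -> bool:
        return self.elem.storage().data_ptr() == other.elem.storage().data_ptr()

x = torch.ones(4)
print(WrapperTensor(x[:3]).aliases(WrapperTensor(x[1:])))  # True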

cc @Chillee

ezyang commented 2 years ago

Can't we just make a synthetic base? You know the size and the stride of each of the input tensors, you know they share a storage, so you can reinterpret them as views on a freshly created tensor representing the 1D storage.

Also, a.copy_ seems wrong; it seems like you should only copy to the base.

ezyang commented 2 years ago

If you want to get super fancy, you can compute the minimum extent of the original tensor necessary to serve all the views you are given. But I don't think it's important to make this efficient, just to make it correct.

bdhirsh commented 2 years ago

Can't we just make a synthetic base? You know the size and the stride of each of the input tensors, you know they share a storage, so you can reinterpret them as views on a freshly created tensor representing the 1D storage.

Agreed that we can do that in most cases, but not for all cases (in example 2 above). My question is, given that "synthesizing a base" won't work for all cases, should we just not bother with it and use the ._base attribute that we have? (or we could do both - make a best effort to synthesize the base, fall back to ._base if that doesn't work, and only then error if inference mode is on).

Also, a.copy_ seems wrong; it seems like you should only copy to the base.

Oh yes agreed - if we have access to the base as a program input, we should only copy to that and not any of its views (seems fixable)

ezyang commented 2 years ago

Could you explain more why it doesn't work in example 2?

bdhirsh commented 2 years ago

Talked more with Ed offline. Summarizing:

Agreed that we can synthesize a base in all cases, and we don't need the ._base attribute. How? For all program inputs that alias each other, we have access to their underlying storage. We can generate a fresh base tensor from that storage, and manually re-create the aliasing relationships between that base and all of the aliased inputs.

So the steps are:

(1) Group all tensor arguments by their aliases (e.g. with Dict[StorageWeakRef, List[Tensor]]).
(2) For every storage that has more than one alias (view1 and view2):
    (a) Generate a synthetic base tensor: base = torch.from_blob(storage, storage.numel())
    (b) Add base as a new program input.
    (c) Manually add the view relationship between view1, view2 and base. Inside of functionalize(), we can do this with at::functionalization::impl::create_functional_tensor_with_view_meta() when we wrap view1 and view2 into functional tensors.

The easiest way to add the view relationships would be with as_strided, but we could probably work harder to figure out what the original view operator was and re-ify that instead.
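A rough Python sketch of steps (1)-(2) (torch.from_blob is a C++ API, so here the synthetic base is just the whole storage reinterpreted via as_strided; assumes a recent PyTorch, and is illustrative rather than what functionalize() actually does):

import torch
from collections import defaultdict

def group_by_storage(tensors):
    # Step (1): bucket the tensor arguments by the storage they alias.
    groups = defaultdict(list)
    for t in tensors:
        groups[t.storage().data_ptr()].append(t)
    return groups

def synthesize_base(view):
    # Step (2a): reinterpret the whole underlying storage as a fresh 1-D base.
    storage_numel = view.untyped_storage().nbytes() // view.element_size()
    return view.as_strided((storage_numel,), (1,), 0)

def reexpress_view(base, view):
    # Step (2c), simplest form: re-create each aliased input as an as_strided
    # view of the synthetic base (as_strided takes an absolute storage offset).
    return base.as_strided(view.size(), view.stride(), view.storage_offset())

x = torch.ones(4)
a, b = x[:3], x[1:]
for views in group_by_storage([a, b]).values():
    if len(views) > 1:
        base = synthesize_base(views[0])
        print(all(torch.equal(reexpress_view(base, v), v) for v in views))  # True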

Also, when we add base as a new program input, we need to take care that the new function emitted by functionalize() keeps the same signature as the original. So instead of changing the signature, we need to intercept in the middle - probably by having the outer function internally call a new function with an extra argument, passing in any synthetic bases. A minimal sketch of that interception is below.
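# Hypothetical sketch (names are made up): the user-visible function keeps the
# original signature f(a, b), while internally calling a traced inner function
# that also receives the synthetic base.
def make_outer(inner_f, synthetic_base):
    def outer(a, b):
        # Same signature as the original program; the extra base input is
        # supplied here rather than by the caller.
        return inner_f(a, b, synthetic_base)
    return outer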

laithsakka commented 1 month ago

cc @zou3519 this seems very related to what we are doing?

zou3519 commented 1 month ago

@laithsakka this is something else: AOTAutograd handles aliased program inputs; the functorch.functionalize API (which is different from the functionalization used by AOTAutograd) doesn't.

bdhirsh commented 3 weeks ago

AOTAutograd handles aliased program inputs

for better or worse 😛