Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Handling inplace through SSA #145

Open t-vi opened 5 months ago

t-vi commented 5 months ago

This issue is to facilitate discussion of inplace handling, namely the "big" solution of having a static single assignment (SSA) representation.

For any handling of inplace, we want to make certain that two things are achieved:

Some thoughts from video/chat discussions:

About the problem:

Solution considerations:

Later versions could refine the alias analysis as needed.

@tfogal @mruberry @IvanYashchuk

mruberry commented 5 months ago

One thing I've struggled with in these discussions is: how does static single assignment form address the challenges of inplace operations? Doesn't Thunder's IR essentially already have the SSA property? (I think there are some cases where we model operations that do nothing as the equivalent of x = x, but I don't think they're a problem.)

t-vi commented 5 months ago

Well, we do have SSA form in the Thunder IR until we allow inplace operations, which seems to be one of the things people want to do. The transformation to SSA is just there to make sure we don't run into correctness issues.

mruberry commented 5 months ago

> Well, we do have SSA form in the Thunder IR until we allow inplace operations, which seems to be one of the things people want to do. The transformation to SSA is just there to make sure we don't run into correctness issues.

OK; I guess I think about this as "let's handle inplace operations while preserving our SSA form and relying on dataflow to determine the order of operations"?

t-vi commented 5 months ago

Exactly, or even "let's preserve our SSA form and the fact that dataflow describes the order of operations (and admissible reorderings) even when we want inplace".

mruberry commented 5 months ago

Maybe we can afford to be overly conservative with operations like reshape and, for ambiguous cases, act like they (might have) created a view?

What gets a little tricky is that different executors may have different rules for when they do/don't create views. It's interesting to think about how different executors might communicate this.

Alternatively, maybe we should define semantics like "if the reshaped tensor shares storage with another tensor, then the result must be a tensor that doesn't share its storage", and we can update the torch executor to respect those semantics?
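
For context, eager PyTorch already has exactly this ambiguity: reshape returns a view when it can and a copy otherwise. A small eager-mode illustration (using the same storage-pointer check that comes up later in this thread):

import torch

x = torch.arange(6).view(2, 3)   # contiguous
y = x.reshape(3, 2)              # view: reshape of a contiguous tensor shares storage
xt = x.t()                       # non-contiguous transpose
z = xt.reshape(6)                # copy: no view with these strides is possible
print(y.untyped_storage().data_ptr() == x.untyped_storage().data_ptr())   # True
print(z.untyped_storage().data_ptr() == xt.untyped_storage().data_ptr())  # False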

apaz-cli commented 5 months ago

> How does static single assignment form address the challenges of inplace operations?

For the most part, it does so by assuming that these sorts of side effects don't exist. When you convert to SSA, you're meant to rename a variable any time it's potentially modified, and hide any conditional logic, jumps, etc. behind phi nodes. That's the central idea of SSA. We have no control flow, so we don't need phi nodes. All we need is a better way than the variable name to track the identity of the tensor.

So you would rename:

a = torch.zeros(5)
b = torch.ones(5)
a.add_(b)
return a

To something like:

t0 = torch.zeros(5)
t1 = torch.ones(5)
t2 = t0.add(t1)
return t2

All that's required to do that is to iterate through the trace, operation by operation, renaming every tensor that's modified. At the end though, you run into a problem. Suppose a was an argument. You need a way to assign back into it.

def foo(a: torch.Tensor):
  # Do other stuff
  b = torch.ones(5)
  a.add_(b)
  return a

becomes

def foo(t0: torch.Tensor):
  # Do other stuff
  t1 = torch.ones(5)
  t2 = t0.add(t1)
  tensor_memcpy(t0, t2) # to, from
  return t0

I'm not sure what we should do in this situation. NVFuser has it figured out. Just write the answer back inside. But we would have to do it with torch ops. If there's a way to perform this memcpy, ideally in a way that can be easily optimized out, let me know. In that case, I think writing a pass that functionalizes this stuff is pretty easy. If there isn't, I'm not sure how to do an SSA-style functionalization pass here.

To support this, we would only have to add either:
A) TensorProxies have an identity (which we can disregard after functionalization when the user is writing their passes)
B) Symbols contain a list of all the tensor references that they write to

Either would work.
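
For what it's worth, if torch.Tensor.copy_ is an acceptable stand-in for the hypothetical tensor_memcpy above (t-vi's sketch further down also spells the write-back as copy_), the functionalized version could look like the following. This is just a sketch, not an existing Thunder pass:

import torch

def foo_functionalized(t0: torch.Tensor):
  # Do other stuff
  t1 = torch.ones(5)
  t2 = t0.add(t1)   # out-of-place replacement for the original add_
  t0.copy_(t2)      # write the result back into the argument's storage
  return t0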

mruberry commented 5 months ago

That's the thing about functionalizing -- it alone is not enough. If that's the approach that's taken then there must be a later pass that has information not present in the Python program which can update the memory as needed. The situation is also more complicated than a straightforward example, like the one above, suggests. If a tensor is written to inplace then how many tensors are modified? Maybe one, maybe two, maybe ten. How are these relationships expressed with a functionalization pass? I guess the operation would have to be repeated n times? And it's unclear if such programs could be optimized as well as programs that can express their inplace logic directly.

IvanYashchuk commented 5 months ago

As a reminder, PyTorch returns new Tensors for inplace operations; the return isn't just None. So blocks of code like

a.add_(b)
return a

could equivalently be rewritten as

c = a.add_(b)
return c

and with this rewrite, maybe we don't need any special trace pass to reconstruct the dataflow?
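
A quick eager-mode check of that rewrite (see also apaz-cli's note below that the returned tensor is the same Python object as a):

import torch

a = torch.zeros(5)
b = torch.ones(5)
c = a.add_(b)
print(c is a)  # True: the inplace op returns its input
print(a)       # tensor([1., 1., 1., 1., 1.])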

t-vi commented 5 months ago

I am not sure that "there is a PyTorch programming style that makes it easy for us" helps us that much, because silently producing incorrect results for valid PyTorch code unless users stick to some subset does not seem like a good option.

If you don't have better ideas, I would like to have

  a = torch.randn(4)
  b = a.view(2, 2)[1]
  b += 1

to translate to something along the lines of

   a = torch.randn(4)
   b = a.view(2, 2)[1]
   c = b + 1
   b(new) = copy_(b, c)
   a(new) = _update_storage_(a, b(new))

with b(new) and a(new) being new proxies. An optimizer can then take the copy_ and fuse it into the computation. The main nuisances I see are

  • before reaching the SSA, proxies will need to be versioned in a way that disambiguates the (new) part here,
  • the execution either needs to make sure that b is still a view into a or know which bits to copy to which bits of a,
  • strictly speaking, we would need to preserve the "is view into" property, but it is an implementation detail with reshape (I'd be inclined to treat reshape as always creating a new tensor for this aspect).
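
A small eager-mode illustration of why the copy_ into b suffices to update a, provided b really is still a view into a (the second nuisance above):

import torch

a = torch.randn(4)
b = a.view(2, 2)[1]   # b views the last two elements of a
c = b + 1
b.copy_(c)            # writing through the view updates a's storage
print(torch.equal(a[2:4], c))  # True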

apaz-cli commented 5 months ago

@IvanYashchuk yes, but `a is c` is True. The inplace op just returned a, and both Python variables point to the same object.

The problem that I think @mruberry is referring to is the problem of views. In general, it's really difficult to know how many tensors are actually being modified by an inplace operation. Consider the following:

a = torch.zeros(5)
b = a[0:1]
c = a[1:3]
b[0] = 1
c[1] = 1
print(a is b) # False
print(a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()) # True
print(a) # [1, 0, 1, 0, 0]

How do we know, when we write to the backing storage of b, that we have to rename a? And can we find a way not to rename b when we write to c? It's a famous open problem in PyTorch.

I'm... not really saying that I have a general solution for the "how many tensors am I actually changing when I do an inplace operation on a view" problem. But I do think that we can warn users when they do inplace operations on tensors that we know are a view. And I also think that the other cases are pretty easy. If they still aren't, let me know that I'm wrong.

I see that @t-vi just posted, and I think we independently came to largely the same conclusion. It should be easy enough to tag which tensors are views. So, we'd have to add two bits of info to meta functions. Or to tensors, or some combination. Either works.
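
One conservative way to tag potential aliasing among concrete inputs, in the spirit of the storage check above (may_alias is a hypothetical helper, not an existing Thunder or PyTorch API):

import torch

def may_alias(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Conservative: treat any two tensors that share a storage as potential aliases,
    # accepting false positives for non-overlapping views like b and c below.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

base = torch.zeros(5)
b, c = base[0:1], base[1:3]
print(may_alias(base, b), may_alias(b, c))  # True True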

mruberry commented 5 months ago

> ... I would like to have
>
>   a = torch.randn(4)
>   b = a.view(2, 2)[1]
>   b += 1
>
> to translate to something along the lines of
>
>    a = torch.randn(4)
>    b = a.view(2, 2)[1]
>    c = b + 1
>    b(new) = copy_(b, c)
>    a(new) = _update_storage_(a, b(new))
>
> with b(new) and a(new) being new proxies. An optimizer can then take the copy_ and fuse it into the computation. The main nuisances I see are
>
>   • before reaching the SSA, proxies will need to be versioned in a way that disambiguates the (new) part here,
>   • the execution either needs to make sure that b is still a view into a or know which bits to copy to which bits of a,
>   • strictly speaking, we would need to preserve the "is view into" property, but it is an implementation detail with reshape (I'd be inclined to treat reshape as always creating a new tensor for this aspect).

The thing I don't like about this approach is that there's nothing in the program that says a cannot be used after the creation of a(new), and if we were to introduce such a concept then we'd have to add significant complexity to our existing dataflow-based passes.

A dataflow-based approach to ensure the correctness of these operations might look something like

a, a_storage0 = torch.randn(4)
b, _ = a.view(2, 2, storage=a_storage0)[1]  # maybe b, a_storage0 = would be clearer?
c, a_storage1 = torch.add_(b, 1, storage=a_storage0)

and then if someone performed an operation like a + 1 later it would look like

d, d_storage0 = torch.add(a, 1, storage=a_storage1)

Now obviously that's pretty ugly, and I think there's room for refinement, but if we have some mechanism for accepting and updating storage generations, like in the above, then you can order the operations through dataflow alone. In particular, the a + 1 operation could not be reordered before the inplace add.

I think what's tricky about this approach is thinking about how to best associate views and storage, but the tuple (view, storage) seems like a really good perspective to have.

Edit: we could probably make every current TensorProxy a tuple of view and storage, so that calls like d = torch.add(a, 1) are really something like (d_view, d_storage0) = torch.add((a_view, a_storage1), 1), but we wouldn't have to make the split so prominent when printing the program (the tuples could have names).
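
A very rough sketch of that (view, storage) pairing as plain Python data structures; StorageProxy and the generation counter here are illustrative stand-ins, not Thunder's actual proxy classes:

from dataclasses import dataclass

@dataclass(frozen=True)
class StorageProxy:
    name: str
    generation: int = 0

    def bump(self) -> "StorageProxy":
        # An inplace op yields a new storage generation; anything that read the old
        # generation cannot be reordered after the write without breaking dataflow.
        return StorageProxy(self.name, self.generation + 1)

@dataclass(frozen=True)
class TensorProxy:
    name: str
    storage: StorageProxy

# d = torch.add(a, 1) would consume a.storage (generation n);
# torch.add_(b, 1) would additionally produce b.storage.bump() (generation n + 1).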

t-vi commented 5 months ago

Seems like #264 would also benefit from an SSA/functionalization pass, as it also deals with implicit state (except that it seems simpler in that we don't need to worry about aliasing).

jjsjann123 commented 5 months ago

Want to log one thing @mruberry mentioned in an offline discussion.

The scope of aliasing that Thunder is trying to support would also include aliases across program inputs.

I think this makes SSA handling trickier, since we might not be able to reason about how to replay an inplace update on aliased inputs.

def foo(a, b):
  c = a.add_(1.0)
  e = b * 2
  return e

assuming a and b are aliases, SSA would need to replay the inplace update as something like this first (and then de-SSA to write the update back to a's buffer):

  a0 = a.add(1.0)
  b0 = b.add(1.0)
  e = b0 * 2
  return e

But if a and b merely overlap, we wouldn't be able to replay a.add_(1.0) on b unless we know how to model the overlap.
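
A concrete eager-mode example of the "just overlap" case, which is why replaying a.add_(1.0) on b is not a simple rename:

import torch

base = torch.zeros(6)
a = base[0:4]   # elements 0..3 of base
b = base[2:6]   # elements 2..5 of base: overlaps a only at 2..3
a.add_(1.0)
print(b)        # tensor([1., 1., 0., 0.]): only the overlapping part of b changed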

mruberry commented 4 months ago

> Want to log one thing @mruberry mentioned in an offline discussion.
>
> The scope of aliasing that Thunder is trying to support would also include aliases across program inputs.
>
> I think this makes SSA handling trickier, since we might not be able to reason about how to replay an inplace update on aliased inputs.
>
> def foo(a, b):
>   c = a.add_(1.0)
>   e = b * 2
>   return e
>
> assuming a and b are aliases, SSA would need to replay the inplace update as something like this first (and then de-SSA to write the update back to a's buffer):
>
>   a0 = a.add(1.0)
>   b0 = b.add(1.0)
>   e = b0 * 2
>   return e
>
> But if a and b merely overlap, we wouldn't be able to replay a.add_(1.0) on b unless we know how to model the overlap.

I think @t-vi had an interesting idea to have an instruction like

a0, b0 = update(a, b)

which could provide information about the aliasing relationships in the trace, and that might help address this?

jjsjann123 commented 4 months ago

> a0, b0 = update(a, b)

Yeah, that's what we want to do. The question is how exactly we update b0. With the wildest kinds of memory overlap between a and b, it'll be pretty tricky to figure out how a.add_(1.0) maps onto b's storage for that update.

mruberry commented 4 months ago

> a0, b0 = update(a, b)
>
> Yeah, that's what we want to do. The question is how exactly we update b0. With the wildest kinds of memory overlap between a and b, it'll be pretty tricky to figure out how a.add_(1.0) maps onto b's storage for that update.

Agreed! I don't think this problem is, in general, solvable in time polynomial in the shapes and strides of the tensors. If the inplace operations are explicitly represented, I don't think we have that problem.