Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

TensorProxy.shape should be unpacked automatically #1253

Open jjsjann123 opened 2 months ago

jjsjann123 commented 2 months ago

🐛 Bug

TensorProxy.shape remains a plain attribute, so accessing it won't leave an unpack in the trace. This causes issues when there is a NumberProxy in TensorProxy.shape.

In #1201, commit 26f883e33f2231c0f04fa5443a5790e892e93c78, I have to rely on this hack; otherwise, the grad transform would see an invalid trace,

e.g. in a trivial slice:

def foo(a):
  return a[..., : a.shape[-1]]

thunder.jit(foo, cache="symbolic values")
def computation(a, i1):
  # a: "cpu f32 [IntegerProxy name=i0], [IntegerProxy name=i1]"
  # i1: "int 8"
a = ltorch.getitem(a, (..., slice(None, i1, None)))
      # (i0, i1) = prims.shape(a)  # THIS IS WHAT THE HACK DOES
      # b2 = prims.lt(i0, 0)
      # ...

Without the explicit unpack of a.shape, the subsymbols in ltorch.getitem would access i0, which is implicitly carried by a.shape but is not explicitly present in the trace.
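To make the failure mode concrete, here is a minimal, purely illustrative sketch (toy classes, not Thunder's actual TensorProxy or trace machinery) of the difference between a plain shape attribute and a shape access that records an unpack in the trace:

class ToyTrace:
    def __init__(self):
        self.symbols = []                      # recorded (outputs, op, input) entries

    def record(self, outputs, op, arg):
        self.symbols.append((outputs, op, arg))

class ToyTensorProxy:
    def __init__(self, name, shape_names, trace):
        self.name = name
        self._shape_names = shape_names        # stand-ins for the NumberProxy extents, e.g. ("i0", "i1")
        self._trace = trace

    @property
    def shape(self):
        # Recording the query here is the "unpack close to where it is used" behavior;
        # a plain attribute would just return self._shape_names and leave no trace entry.
        self._trace.record(self._shape_names, "prims.shape", self.name)
        return self._shape_names

trace = ToyTrace()
a = ToyTensorProxy("a", ("i0", "i1"), trace)
last_extent = a.shape[-1]                      # the access now leaves an unpack in the trace
print(trace.symbols)                           # [(('i0', 'i1'), 'prims.shape', 'a')]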

Alternative

This problem could also be properly resolved in the prologue trace, i.e. here i1 is unpacked in the prologue because it is consumed by the top-level symbol ltorch.getitem. Unfortunately, uses inside subsymbols are not considered as consumed by the computation trace today (see code), so i0 isn't getting unpacked in the prologue yet.

So for an input TensorProxy, I think prologue unpacking is the right choice here. For intermediate tensors, it might call for a mixed solution, which goes back to the conversation we had in #1133.
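To spell out the consumption question, here is a toy, purely illustrative sketch (made-up ToySymbol / consumed_names, not Thunder's data structures): looking only at the top-level symbol's arguments yields i1, while also walking subsymbols surfaces i0, which is what the prologue would have to unpack for the alternative to work.

from dataclasses import dataclass, field

@dataclass
class ToySymbol:
    name: str
    args: tuple = ()
    subsymbols: list = field(default_factory=list)

def consumed_names(symbol, include_subsymbols):
    """Collect proxy names a symbol consumes, optionally recursing into its subsymbols."""
    names = set(symbol.args)
    if include_subsymbols:
        for sub in symbol.subsymbols:
            names |= consumed_names(sub, include_subsymbols=True)
    return names

getitem = ToySymbol(
    "ltorch.getitem",
    args=("a", "i1"),
    subsymbols=[ToySymbol("prims.lt", args=("i0",))],  # second operand (the literal 0) omitted for brevity
)
print(consumed_names(getitem, include_subsymbols=False))  # only 'a' and 'i1' -> what gets unpacked today
print(consumed_names(getitem, include_subsymbols=True))   # also 'i0' -> what prologue unpacking would need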

t-vi commented 2 months ago

> This problem could also be properly resolved in the prologue trace, i.e. here i1 is unpacked in the prologue because it is consumed by the top-level symbol ltorch.getitem. Unfortunately, uses inside subsymbols are not considered as consumed by the computation trace today (see code), so i0 isn't getting unpacked in the prologue yet.

I'm a bit skeptical about the alternative here, and my gut feeling is that the main solution (to unpack the shape "close" to where it is used) is preferable. To my mind, the alternative implies a major change: the subsymbols, considered as a block of code, would have inputs that the symbol itself does not. This is tricky on several layers (producers / consumers etc.).

The tricky thing with re-unpacking could be that I'm not sure we handle one name being assigned to multiple times well, so we may need new names every time we do this.
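For illustration, a minimal sketch (toy code, not Thunder's naming machinery) of the kind of renaming this would imply: every re-unpack mints a fresh name and remembers which original proxy it aliases, so each name is still assigned exactly once.

import itertools

class FreshNames:
    def __init__(self):
        self._counter = itertools.count()
        self.alias_of = {}                    # fresh name -> original proxy name

    def rename(self, original):
        fresh = f"{original}_{next(self._counter)}"
        self.alias_of[fresh] = original
        return fresh

names = FreshNames()
print(names.rename("i0"))                     # i0_0, first re-unpack of i0
print(names.rename("i0"))                     # i0_1, second re-unpack, still a unique name
print(names.alias_of)                         # {'i0_0': 'i0', 'i0_1': 'i0'}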

jjsjann123 commented 1 month ago

Sorry, I missed this email earlier (gmail access is limited at the moment, depending on how reliable the VPN is). I'm working on the automatic unpacking in #1260.

> The tricky thing with re-unpacking could be that I'm not sure we handle one name being assigned to multiple times well, so we may need new names every time we do this.

I'm already hitting this one. In #1260, every shape query results in a prims.shape in the trace, and since the primitive returns the NumberProxy carried by the tensor, we end up assigning a given proxy name multiple times in the trace, breaking SSA.

e.g. with the following program:

def foo(a):
  return torch.reshape(a, [a.numel()]).relu()

We have a trace:

def computation(a, i0, i1):
  i2 = operator.mul(1, i0)
  i3 = operator.mul(i2, i1)
  # ...
  t11 = torch.reshape(a, [i3])
  # ...
  t20 = torch.nn.functional.relu(t11, False)
  # ...

I'm seeing a couple of issues here:

  1. We could have multiple identical unpackings in subsymbols, e.g. in the decomposition of relu we have prims.shape(t11) recorded multiple times. This could also happen when multiple TensorProxys with the same shape are queried. I think this can be cleaned up with a CSE/DCE pass through subsymbols so we keep just a single query (a rough sketch of such a pass follows at the end of this comment).

    # t39 = ltorch.gt(t11, 0)
    # (i3,) = prims.shape(t11)
    # (i3,) = prims.shape(t11)
    # ...
    # t39 = prims.gt(t11, 0.0)
  2. We have unpacking of # (i0, i1) = prims.shape(a) in the subsymbol, even though a is already unpacked in the prologue. This could also happen for shape queries at different levels. These won't be an issue until we flatten the symbol. I think this could be something that we just leave for the code that calls flattening to handle (like a fusion pass?).

    t11 = torch.reshape(a, [i3])
    # t11 = ltorch.reshape(a, [i3])
      # (i0, i1) = prims.shape(a)
      # b14 = prims.eq(i0, i3)
      # b15 = prims.ge(i3, 0)
      # i18 = prims.mul(1, i3)
      # t11 = prims.reshape(a, (i3,))
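As a follow-up to issue 1 above, here is a minimal sketch of the kind of CSE/dedup cleanup mentioned there (toy representation of subsymbols as (op, input, outputs) tuples, not Thunder's BoundSymbol API): keep the first shape query per tensor and drop identical repeats.

def dedup_shape_queries(subsymbols):
    """Keep only the first prims.shape query per input tensor; drop identical repeats."""
    seen = set()
    kept = []
    for op, arg, outputs in subsymbols:
        if op == "prims.shape":
            if arg in seen:
                continue                      # identical unpack already recorded
            seen.add(arg)
        kept.append((op, arg, outputs))
    return kept

subsymbols = [
    ("prims.shape", "t11", ("i3",)),
    ("prims.shape", "t11", ("i3",)),          # duplicate query from relu's decomposition
    ("prims.gt", "t11", ("t39",)),
]
print(dedup_shape_queries(subsymbols))
# [('prims.shape', 't11', ('i3',)), ('prims.gt', 't11', ('t39',))]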