jjsjann123 opened this issue 2 months ago
> This problem can also be properly resolved in the prologue trace, i.e. here `i1` is unpacked in the prologue because it is consumed by the top-level symbol `ltorch.getitem`. Unfortunately, uses inside subsymbols are not considered as consumed by the computation trace today (see code), so `i0` isn't getting unpacked in the prologue yet.
I'm a bit skeptical about the alternative here, and my gut feeling is that the main solution (to unpack the shape "close" to where it is used) is preferable. To my mind, the alternative implies a major change: the subsymbols, considered as a block of code, would have inputs that the symbol itself does not. That seems tricky on several layers (producers / consumers, etc.).

The tricky thing with re-unpacking could be that I'm not sure we are good at having one name assigned multiple times, so we may need new names every time we do this.
Sorry, I missed this email earlier (gmail access is limited at the moment, depending on how reliable the vpn is). I'm working on the automatic unpacking in #1260.
> The tricky thing with re-unpacking could be that I'm not sure we are good at having one name assigned multiple times, so we may need new names every time we do this.
I'm already hitting this one. In #1260, every shape query results in a `prims.shape` in the trace, and since the primitive returns the NumberProxy objects carried by the tensor, we end up assigning a given name multiple times in the trace, breaking SSA.
e.g. with the following program:
```python
def foo(a):
    return torch.reshape(a, [a.numel()]).relu()
```
We have a trace:
```python
def computation(a, i0, i1):
  i2 = operator.mul(1, i0)
  i3 = operator.mul(i2, i1)
  # ...
  t11 = torch.reshape(a, [i3])
  # ...
  t20 = torch.nn.functional.relu(t11, False)
  # ...
```
I'm seeing a couple of issues here:

We could have multiple identical unpackings in subsymbols, e.g. in the decomposition of `relu` we have `prims.shape(t11)` recorded multiple times. This can also happen when multiple TensorProxys with the same shape are queried. I think this can be cleaned up with a CSE/DCE pass through subsymbols so that we keep only a single query; a sketch of such a pass follows the example below.
```python
# t39 = ltorch.gt(t11, 0)
  # (i3,) = prims.shape(t11)
  # (i3,) = prims.shape(t11)
  # ...
  # t39 = prims.gt(t11, 0.0)
```
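Here's a minimal sketch of what such a dedup/CSE pass over subsymbols could look like. The `BoundSym` record, its field names, and `cse_subsymbols` are hypothetical illustrations, not Thunder's actual trace API; the real pass would operate on Thunder's BoundSymbol objects.

```python
# Minimal sketch, not Thunder's actual API: deduplicate identical subsymbols
# (e.g. two `(i3,) = prims.shape(t11)` entries) by keying on (symbol name, argument names)
# and keeping only the first occurrence.
from dataclasses import dataclass, field

@dataclass
class BoundSym:                       # hypothetical record, for illustration only
    name: str                         # e.g. "prims.shape"
    args: tuple                       # argument names, e.g. ("t11",)
    outs: tuple                       # output names, e.g. ("i3",)
    subsymbols: list = field(default_factory=list)

def cse_subsymbols(bsym: BoundSym) -> BoundSym:
    seen, deduped = set(), []
    for sub in bsym.subsymbols:
        key = (sub.name, sub.args)
        if key in seen:
            continue                  # an identical query was already recorded; drop it
        seen.add(key)
        deduped.append(cse_subsymbols(sub))  # clean up nested decompositions too
    return BoundSym(bsym.name, bsym.args, bsym.outs, deduped)
```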
We also have the unpacking `# (i0, i1) = prims.shape(a)` in the subsymbols, even though `a`'s shape is already unpacked in the prologue. The same can happen for shape queries at different nesting levels. These won't be an issue until we flatten the symbol, so I think this is something we can just leave for the code that calls flattening to handle (like a fusion pass?); see the sketch after the example below.
```python
t11 = torch.reshape(a, [i3])
  # t11 = ltorch.reshape(a, [i3])
    # (i0, i1) = prims.shape(a)
    # b14 = prims.eq(i0, i3)
    # b15 = prims.ge(i3, 0)
    # i18 = prims.mul(1, i3)
    # t11 = prims.reshape(a, (i3,))
```
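For illustration, here is a minimal sketch of how the code that calls flattening could skip shape unpacks whose outputs are already known (e.g. unpacked in the prologue). It reuses the hypothetical `BoundSym` record from the sketch above and is not Thunder's actual flattening API.

```python
# Minimal sketch, reusing the hypothetical BoundSym record from above; not Thunder's API.
# When flattening a symbol into its decomposition, drop `prims.shape` subsymbols whose
# outputs are already available (unpacked in the prologue or produced earlier).

def flatten(bsym: BoundSym, known_names: set) -> list:
    flat = []
    for sub in bsym.subsymbols:
        if sub.name == "prims.shape" and all(o in known_names for o in sub.outs):
            continue                        # shape already unpacked upstream; skip re-unpacking
        if sub.subsymbols:
            flat.extend(flatten(sub, known_names))
        else:
            flat.append(sub)
        known_names.update(sub.outs)        # later duplicates of these outputs can be skipped
    return flat
```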
🐛 Bug
`TensorProxy.shape` remains a plain attribute, hence accessing it won't leave an unpack in the trace. This causes issues when we have NumberProxy entries in `TensorProxy.shape`.

In #1201, commit 26f883e33f2231c0f04fa5443a5790e892e93c78, I have to rely on this hack; otherwise the grad transform would see an invalid trace, e.g. in a trivial slice. Without the explicit unpack of `a.shape`, the subsymbols in `ltorch.getitem` would access `i0`, which is implicitly carried by `a.shape` but not explicitly present in the trace.
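For concreteness, here is an illustrative sketch of the kind of program meant by "a trivial slice". This is not the original repro from the issue; the use of `thunder.jit` and the exact slicing are assumptions for illustration only.

```python
# Illustration only; not the original example from the issue.
# ltorch.getitem's decomposition consults a's shape internally, so even a trivial
# slice ends up referencing the NumberProxy (i0) carried by a.shape.
import torch
import thunder

def foo(a):
    return a[0:1]

jfoo = thunder.jit(foo)
out = jfoo(torch.randn(4, 3))
```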
Alternative

This problem can also be properly resolved in the prologue trace, i.e. here `i1` is unpacked in the prologue because it is consumed by the top-level symbol `ltorch.getitem`. Unfortunately, uses inside subsymbols are not considered as consumed by the computation trace today (see code), so `i0` isn't getting unpacked in the prologue yet.

So for input TensorProxy, I think prologue unpacking is the right choice here. For intermediate tensors, it might be a mixed solution, which goes back to the conversation we had in #1133.
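As a closing illustration of the prologue-unpacking direction, here is a minimal sketch, again using the hypothetical `BoundSym` record from the earlier sketches rather than Thunder's actual trace API: to let the prologue unpack names like `i0`, the collection of what the computation trace consumes would need to recurse into subsymbols instead of looking only at top-level arguments.

```python
# Minimal sketch, reusing the hypothetical BoundSym record from above; not Thunder's API.
# Collect every name consumed anywhere in the trace, including inside decompositions,
# so the prologue can also unpack values (like i0) that only appear in subsymbols.

def consumed_names(bsyms: list) -> set:
    names = set()
    for bsym in bsyms:
        names.update(a for a in bsym.args if isinstance(a, str))
        names |= consumed_names(bsym.subsymbols)   # the recursion that is missing today
    return names
```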