gabeweisz opened 7 months ago
Ok, I think this is a bug, as reported to me offline by one of our devs: since the 2.3 upgrade, the aot.export() method no longer applies the default decompositions. This will cause many issues downstream. I'm working on a fix.
I expect this patch will resolve the issue: https://github.com/nod-ai/SHARK-Turbine/pull/574
(landed)
Here is a small repro:

    import torch
    import torch.nn as nn
    import shark_turbine.aot as aot

    class Split(nn.Module):
        def forward(self, x: torch.Tensor):
            return torch.split(x, 2)

    model = Split()
    exported = aot.export(
        model,
        args=(torch.empty([5, 2], dtype=torch.float32),),
    )

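For reference, torch.split with an integer split size cuts the tensor into chunks of that size along dim 0, with a smaller last chunk when the size does not divide evenly, so the repro's [5, 2] input yields three tensors. The pure-Python sketch below computes those output shapes; split_output_shapes is a hypothetical helper, not part of torch or shark_turbine. It shows why the split node's value is presumably a list of tensors rather than a single tensor.

```python
from typing import List


def split_output_shapes(shape: List[int], split_size: int, dim: int = 0) -> List[List[int]]:
    """Shapes returned by torch.split(x, split_size, dim) for an integer split_size.

    Hypothetical helper mirroring torch.split's documented behavior for a
    non-empty dim: equal chunks of split_size along `dim`, with a smaller
    final chunk when the size is not divisible by split_size.
    """
    if split_size <= 0:
        raise ValueError("split_size must be positive")
    total = shape[dim]
    shapes = []
    for start in range(0, total, split_size):
        chunk = list(shape)
        # The last chunk only gets whatever is left along `dim`.
        chunk[dim] = min(split_size, total - start)
        shapes.append(chunk)
    return shapes


# The repro: torch.split(torch.empty([5, 2]), 2) returns three tensors.
print(split_output_shapes([5, 2], 2))  # [[2, 2], [2, 2], [1, 2]]
```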
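With the current commit, this throws an exception in fx_importer.py::node_val_to_type, apparently because the value of the torch.split node is a list of fake tensors rather than a single tensor.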
That function can be fixed by adding something like this at line 937:

    if isinstance(val, list) and all(isinstance(i, TorchFakeTensor) for i in val):
        return [
            self.get_vtensor_type(i.size(), i.dtype, sparsity=sparsity, mutable=mutable)
            for i in val
        ]

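To see what that branch returns for the repro, here is a self-contained mock. FakeTensorStub, get_vtensor_type, and node_val_to_type below are hypothetical stand-ins, not the real fx_importer code, and the type strings only imitate torch-mlir's !torch.vtensor<[...],dtype> spelling. The point is that a list of fake tensors maps to a list of types, one per element, instead of a single type.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class FakeTensorStub:
    """Hypothetical stand-in for a FakeTensor: just a shape and a dtype name."""
    shape: Tuple[int, ...]
    dtype: str

    def size(self) -> Tuple[int, ...]:
        return self.shape


def get_vtensor_type(size: Tuple[int, ...], dtype: str) -> str:
    # Mimics the textual form of a torch-mlir value-tensor type.
    dims = ",".join(str(d) for d in size)
    return f"!torch.vtensor<[{dims}],{dtype}>"


def node_val_to_type(val: Union[FakeTensorStub, List[FakeTensorStub]]) -> Union[str, List[str]]:
    # Mirrors the proposed patch: a list of fake tensors becomes a list of types.
    if isinstance(val, list) and all(isinstance(i, FakeTensorStub) for i in val):
        return [get_vtensor_type(i.size(), i.dtype) for i in val]
    if isinstance(val, FakeTensorStub):
        return get_vtensor_type(val.size(), val.dtype)
    raise NotImplementedError(f"unsupported node value: {val!r}")


# Value of the torch.split node in the repro: three chunks of a [5, 2] f32 tensor.
split_val = [
    FakeTensorStub((2, 2), "f32"),
    FakeTensorStub((2, 2), "f32"),
    FakeTensorStub((1, 2), "f32"),
]
print(node_val_to_type(split_val))
```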
However, doing so causes an error in _emit_operation, because result_types is now a list of lists of types (a single entry holding the list of result types) rather than a flat list of types.
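One possible way around the nesting, sketched here only as an illustration, is to flatten result_types before _emit_operation uses it, so that an entry that is itself a list of types contributes each type as a separate result. The flatten_result_types helper is hypothetical and works on plain strings standing in for MLIR types; it is not necessarily how fx_importer should be changed.

```python
from typing import List, Sequence, Union

TypeLike = str  # stand-in for an MLIR type in this sketch


def flatten_result_types(result_types: Sequence[Union[TypeLike, Sequence[TypeLike]]]) -> List[TypeLike]:
    """Flatten one level of nesting, e.g. [[t1, t2, t3]] -> [t1, t2, t3]."""
    flat: List[TypeLike] = []
    for entry in result_types:
        if isinstance(entry, (list, tuple)):
            # A multi-result node contributed a list of types: splice them in.
            flat.extend(entry)
        else:
            flat.append(entry)
    return flat


# What the split node produces once node_val_to_type returns a list: a one-entry list of lists.
nested = [["!torch.vtensor<[2,2],f32>", "!torch.vtensor<[2,2],f32>", "!torch.vtensor<[1,2],f32>"]]
print(flatten_result_types(nested))
```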