nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

New torch export path does not support torch.split (and possibly other operations returning a list of tensors) #565

Open gabeweisz opened 7 months ago

gabeweisz commented 7 months ago

Here is a small repro:

```python
import torch
import torch.nn as nn
import shark_turbine.aot as aot


class Split(nn.Module):
    def forward(self, x: torch.Tensor):
        return torch.split(x, 2)


model = Split()
exported = aot.export(
    model,
    args=(torch.empty([5, 2], dtype=torch.float32),),
)
```

The current commit throws an exception in fx_importer.py::node_val_to_type.

That function can be fixed by adding something like this at line 937:

```python
if isinstance(val, List) and all(isinstance(i, TorchFakeTensor) for i in val):
    return list(
        self.get_vtensor_type(i.size(), i.dtype, sparsity=sparsity, mutable=mutable)
        for i in val
    )
```

However, doing so causes an error in _emit_operation, because result_types is then a list of lists of types (with one entry: the list of result types) rather than a flat list of types.
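For illustration only, here is the shape mismatch being described, using hypothetical placeholder strings in place of real MLIR types (this is just a sketch of the data shapes, not a proposed fix):

```python
from typing import Any, List

# Hypothetical placeholder strings standing in for real torch-mlir result types.
flat_result_types: List[Any] = [
    "!torch.vtensor<[2,2],f32>",
    "!torch.vtensor<[2,2],f32>",
    "!torch.vtensor<[1,2],f32>",
]

# What node_val_to_type returns after the patch above: a single entry that is
# itself the list of per-tensor types, rather than one type per result.
nested_result_types: List[Any] = [flat_result_types]

assert len(flat_result_types) == 3   # one type per split output
assert len(nested_result_types) == 1  # one entry: the whole list
```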

stellaraccident commented 7 months ago

OK, I think this is the bug reported to me offline by one of our devs: after the 2.3 upgrade, the aot.export() method is not applying default decompositions. This will cause many issues. I'm working on a fix.
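For context, here is a minimal sketch of what "applying default decompositions" to an exported program looks like with the stock torch.export APIs; whether the default table rewrites aten.split into tensor-valued ops is an assumption on my part, and the authoritative fix for Turbine's aot.export path is the patch linked below:

```python
import torch
import torch.nn as nn
from torch._decomp import core_aten_decompositions


class Split(nn.Module):
    def forward(self, x: torch.Tensor):
        return torch.split(x, 2)


# Export with plain torch.export, then apply the default ATen decomposition
# table; decomposable ops in the graph are rewritten before any import step.
ep = torch.export.export(Split(), (torch.empty([5, 2], dtype=torch.float32),))
ep = ep.run_decompositions(core_aten_decompositions())
print(ep.graph_module.graph)
```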

stellaraccident commented 7 months ago

I expect this patch will resolve the issue: https://github.com/nod-ai/SHARK-Turbine/pull/574

stellaraccident commented 7 months ago

(landed)