pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[dynamo] aten::sort() is not supported by dynamo_export #127831

Open borisfom opened 1 month ago

borisfom commented 1 month ago

🐛 Describe the bug

I can't export a trivial net that uses sort() with dynamo_export(). The legacy torch.onnx.export() works fine, but dynamo_export errors out:

torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.sort.default']}.

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sort(x, dim=-1, descending=False)

device = torch.device('cuda')
model = Model().to(device)
x = torch.rand(1024, 20, 16).to(device)

# Running decompositions first does not help either:
# model = torch.export.export(model, (x,), strict=False).run_decompositions()

# The legacy (TorchScript-based) exporter works; it writes the file and returns None
torch.onnx.export(model, x, "legacy_sort.onnx")

options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(model, x, export_options=options)
onnx_program.save('model.onnx')

Versions

PyTorch nightly (06/03)

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

titaiwangms commented 1 month ago

@shubhambhokare1 This is a missing op; could you take it?

borisfom commented 1 month ago

Not sure if this is related: https://github.com/pytorch/pytorch/issues/125633

gramalingam commented 3 weeks ago

Hmmm. Ideally, ONNX should add a sort op. It doesn't have one. I wonder what the legacy exporter converts it to.

titaiwangms commented 3 weeks ago

> Hmmm. Ideally, ONNX should add a sort op. It doesn't have one. I wonder what the legacy exporter converts it to.

https://github.com/pytorch/pytorch/blob/c63ccead5effbacfa40db14f11b63b16d3996aaf/torch/onnx/symbolic_helper.py#L790-L807
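As far as I can tell, the linked helper lowers sort to an ONNX TopK node. It takes k from the input's shape along the sorted dimension, so every element is returned, and sets `largest` from `descending`. The sketch below applies the same idea at the PyTorch level, assuming that rewriting the model with torch.topk is an acceptable user-side workaround. `sort_via_topk` and `TopKSortModel` are hypothetical names. Whether dynamo_export in a given nightly handles `aten.topk` has not been verified here. On tied values, the returned indices may order the ties differently from torch.sort, so the check compares values and gathered values rather than raw indices.

```python
import torch
import torch.nn as nn


def sort_via_topk(x: torch.Tensor, dim: int = -1, descending: bool = False):
    # Hypothetical helper: emulate torch.sort with torch.topk by requesting
    # all k = x.size(dim) elements. largest=True gives descending order,
    # largest=False gives ascending order; sorted=True keeps them ordered.
    k = x.size(dim)
    return torch.topk(x, k, dim=dim, largest=descending, sorted=True)


class TopKSortModel(nn.Module):
    # Same computation as the Model in the repro, expressed with topk.
    def forward(self, x):
        return sort_via_topk(x, dim=-1, descending=False)


x = torch.rand(1024, 20, 16)
ref_values, ref_indices = torch.sort(x, dim=-1, descending=False)
values, indices = TopKSortModel()(x)

# Values always agree with torch.sort.
assert torch.equal(values, ref_values)
# Indices may order ties differently, but they must still sort x.
assert torch.equal(x.gather(-1, indices), ref_values)
```

Exporting TopKSortModel with torch.onnx.dynamo_export, as in the original repro, would be the natural next test of this workaround.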