llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

0-dim shapes seen from aten.slice.tensor shape resolution #716

Closed by sjarus 2 years ago

sjarus commented 2 years ago

For TOSA backend testing, we lowered several tests whose legalizations fail because aten.slice.tensor emits an output whose shape has at least one dimension of size 0.

The simplest way to reproduce this is to create a statically shaped IoU module, e.g.:

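import torch

# Import paths follow the torch-mlir e2e test suite layout at the time of this
# issue; they may differ in other torch-mlir versions.
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case

class IouOfModuleStatic(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        # Both inputs: statically shaped [1024, 4] float32 boxes laid out as
        # (x1, y1, x2, y2), with value semantics.
        ([1024, 4], torch.float32, True),
        ([1024, 4], torch.float32, True),
    ])
    def forward(self, bbox1, bbox2):
        # Per-box areas: (x2 - x1) * (y2 - y1).
        area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1])
        area2 = (bbox2[:, 2] - bbox2[:, 0]) * (bbox2[:, 3] - bbox2[:, 1])
        # Corners of the intersection box; the [:, :2] and [:, 2:] slices
        # should both have shape [1024, 2].
        lt = torch.maximum(bbox1[:, :2], bbox2[:, :2])
        rb = torch.minimum(bbox1[:, 2:], bbox2[:, 2:])

        # Clamp negative extents to 0 when the boxes do not overlap.
        overlap_coord = (rb - lt).clip(0)
        overlap = overlap_coord[:, 0] * overlap_coord[:, 1]
        union = area1 + area2 - overlap

        return overlap / union

@register_test_case(module_factory=lambda: IouOfModuleStatic())
def IouOfModuleStatic_basic(module, tu: TestUtils):
    module.forward(tu.rand(1024, 4), tu.rand(1024, 4))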

Running this with the e2e script and --config=tosa should emit a file /tmp/IouOfModuleStatic.mlir which will show something like:

%39 = torch.aten.slice.Tensor %38, %int1, %none, %int2, %int1 : !torch.vtensor<[1024,4],f32>, !torch.int, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[1024,0],f32> loc(#loc13)

This shape is wrong (slicing [:, :2] from a [1024,4] tensor should give [1024,2]), and it leads to legalization failures like:

$MYHOME/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/vision_models.py:167:25: error: 'tosa.sub' op operands don't have broadcast-compatible shapes
        overlap_coord = (rb - lt).clip(0)
                        ^
$MYHOME/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/vision_models.py:167:25: note: see current operation: %92 = "tosa.sub"(%89, %91) : (tensor<1024x2xf32>, tensor<1024x0xf32>) -> tensor<1024x2xf32>
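The `tosa.sub` failure follows from broadcasting rules: a pair of dimensions is compatible only if the two sizes are equal or one of them is 1, so 2 and 0 cannot be reconciled. A minimal sketch, assuming NumPy-style right-aligned broadcasting as an approximation of the check TOSA performs; `broadcast_shape` is a hypothetical helper, not a TOSA or MLIR API:

```python
from typing import Optional, Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Return the broadcast shape of `a` and `b`, or None if incompatible.

    Hypothetical helper using NumPy-style rules: shapes are right-aligned and
    each pair of dimensions must be equal or contain a 1.
    """
    rank = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    pa = [1] * (rank - len(a)) + list(a)
    pb = [1] * (rank - len(b)) + list(b)
    out = []
    for da, db in zip(pa, pb):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            # Sizes differ and neither is 1: not broadcastable.
            return None
    return tuple(out)


# Operands of the failing tosa.sub: rb is 1024x2, lt is 1024x0.
print(broadcast_shape([1024, 2], [1024, 0]))  # None
# With a correctly computed slice shape, both operands are 1024x2.
print(broadcast_shape([1024, 2], [1024, 2]))  # (1024, 2)
```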

This appears to be an off-by-one error in the slice shape computation. Where does Torch-MLIR do this computation? A fix could potentially unblock a number of full-network TOSA cases, and it probably also affects the RefBackend path.

dan-garvey commented 2 years ago


https://github.com/llvm/torch-mlir/blob/main/lib/Conversion/TorchToLinalg/DataMovement.cpp#L717

Edit: actually, this isn't where the logic we're looking for lives, since it doesn't even apply to the TOSA path.

I'll try your test case. Off the top of my head, I know that PyTorch itself allows dimensions of size 0 in tensor shapes, so we modeled the slice lowering to allow them as well.

Edit: I confirmed that this static version of the Iou test case also fails on the RefBackend, while the dynamic version passes, so something about how the static shapes are used (maybe in RefineTypes?) is causing the issue. Could someone more familiar with how static shape information is used chime in or point me in the right direction?

silvasean commented 2 years ago

I have a fix for the shape inference issue (will post PR after meetings).

sjarus commented 2 years ago


Yeah, this isn't specifically about TOSA. It just happens that we identified it in a failed TOSA legalization and worked back from the aten.sub that would not legalize to the op defining its 0-sized operand, which was the slice. I didn't test the RefBackend, but I expected it to fail, as you've confirmed.

RefineTypes was the first place I looked, but I could not find anything there like a visitAtenSliceTensorOp() that could be miscomputing the shape.

silvasean commented 2 years ago

Shapes are now computed with the shape library: https://github.com/llvm/torch-mlir/blob/main/docs/shape_lib.md
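For reference, the expected result of `aten.slice.Tensor` can be written as a small Python shape function, in the spirit of the shape library. The sketch below is our own approximation of PyTorch slicing semantics for positive steps (resolve `None` and negative indices, clamp into range, then take the ceiling of the range length divided by `step`); the name `slice_shape` and its signature are hypothetical and not the shape library's actual entry point. It shows that a size of 0 is legitimate only when the clamped range is empty, which is not the case for the `[:, :2]` slice in the repro.

```python
from typing import List, Optional


def slice_shape(shape: List[int], dim: int, start: Optional[int],
                end: Optional[int], step: int) -> List[int]:
    """Hypothetical shape function for aten.slice.Tensor (positive step only)."""
    assert step > 0, "only positive steps are modeled"
    ndim = len(shape)
    assert -ndim <= dim < ndim, "dim out of range"
    dim = dim % ndim  # wrap a negative dim
    size = shape[dim]
    # None means "from the beginning" / "to the end".
    s = 0 if start is None else start
    e = size if end is None else end
    # Negative indices count from the end of the dimension.
    if s < 0:
        s += size
    if e < 0:
        e += size
    # Clamp start into [0, size], then end into [start, size].
    s = min(max(s, 0), size)
    e = min(max(e, s), size)
    out = list(shape)
    # Number of elements in range(s, e, step), i.e. ceil((e - s) / step).
    out[dim] = (e - s + step - 1) // step
    return out


# %39 from the issue: dim 1 of [1024, 4], start=None, end=2, step=1.
print(slice_shape([1024, 4], 1, None, 2, 1))  # [1024, 2]
# bbox[:, 2:] on the same input.
print(slice_shape([1024, 4], 1, 2, None, 1))  # [1024, 2]
# A legitimately empty slice, e.g. bbox[:, 4:].
print(slice_shape([1024, 4], 1, 4, None, 1))  # [1024, 0]
```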

silvasean commented 2 years ago

@dan-garvey after my fix in https://github.com/llvm/torch-mlir/pull/721, there still seems to be an issue with affine maps in IouOfModuleStatic when lowering linalg to buffers. Can you take a look and add that test along with the fix?