llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

Transpose error when decomposing the convolution backward #1772

Open ZihengJiang opened 1 year ago

ZihengJiang commented 1 year ago

FX Graph:

def forward(self, primals_1, primals_8, tangents_1):
    convolution_backward = torch.ops.aten.convolution_backward.default(tangents_1, primals_8, primals_1, [16], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]);  tangents_1 = primals_8 = primals_1 = None
    getitem = convolution_backward[0]
    getitem_1 = convolution_backward[1]
    getitem_2 = convolution_backward[2];  convolution_backward = None
    return [getitem_1, getitem_2, None, None, None, None, None, getitem]
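
As a reading aid, the positional arguments of the call above can be paired with their names from the aten.convolution_backward schema, using the shapes from the converted function signature in this issue. This is a minimal pure-Python sketch, not torch-mlir code: conv2d_output_size is a hypothetical helper that applies the standard convolution output-size formula to confirm that grad_output has the shape a forward convolution would produce.

```python
# Hypothetical helper for illustration; not part of PyTorch or torch-mlir.
def conv2d_output_size(in_size, kernel, stride=1, padding=0, dilation=1):
    """Output size of a convolution along one spatial dimension."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Positional arguments of the convolution_backward call in the FX graph,
# with tensors replaced by their shapes from the converted IR signature.
args = {
    "grad_output": [2, 16, 198, 198],  # tangents_1 (%arg2)
    "input": [2, 3, 200, 200],         # primals_8 (%arg1)
    "weight": [16, 3, 5, 5],           # primals_1 (%arg0)
    "bias_sizes": [16],
    "stride": [1, 1],
    "padding": [1, 1],
    "dilation": [1, 1],
    "transposed": False,
    "output_padding": [0, 0],
    "groups": 1,
    "output_mask": [True, True, True],
}

batch, _, height, width = args["input"]
out_channels, _, kernel_h, kernel_w = args["weight"]

# Shape the forward convolution would produce; it should equal grad_output.
forward_out = [
    batch,
    out_channels,
    conv2d_output_size(height, kernel_h, args["stride"][0],
                       args["padding"][0], args["dilation"][0]),
    conv2d_output_size(width, kernel_w, args["stride"][1],
                       args["padding"][1], args["dilation"][1]),
]
print(forward_out == args["grad_output"])
```

The three results come back in the order (grad_input, grad_weight, grad_bias), so the return statement above yields grad_weight, grad_bias and grad_input, which matches the result types of the converted function.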

Converted torch dialect:

module attributes {torch.debug_module_name = "GraphModule"} {
  func.func @forward(%arg0: !torch.vtensor<[16,3,5,5],f32>, %arg1: !torch.vtensor<[2,3,200,200],f32>, %arg2: !torch.vtensor<[2,16,198,198],f32>) -> (!torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32>) {
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %none = torch.constant.none
    %false = torch.constant.bool false
    %true = torch.constant.bool true
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
    torch.runtime.assert %true, "unimplemented: only strides of 1 supported."
    torch.runtime.assert %true, "unimplemented: only strides of 1 supported."
    torch.runtime.assert %true, "unimplemented: only dilations of 1 supported."
    torch.runtime.assert %true, "unimplemented: only dilations of 1 supported."
    %2 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.flip %arg0, %2 : !torch.vtensor<[16,3,5,5],f32>, !torch.list<int> -> !torch.vtensor<[16,3,5,5],f32>
    %4 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
    %6 = torch.aten.convolution %arg2, %5, %none, %0, %4, %0, %false, %1, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.vtensor<[16,3,5,5],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,3,200,200],f32>
    %7 = torch.aten.transpose.int %arg2, %int0, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,16,198,198],f32>
    %8 = torch.aten.transpose.int %arg1, %int0, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,200,200],f32>
    %9 = torch.aten.convolution %8, %7, %none, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.vtensor<[2,16,198,198],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
    %10 = torch.aten.transpose.int %9, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
    %11 = torch.prim.ListConstruct %int0, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %12 = torch.aten.sum.dim_IntList %arg2, %11, %false, %none : !torch.vtensor<[2,16,198,198],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32>
    return %10, %12, %6 : !torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32>
  }
}

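It seems that the result type of the aten.transpose.int ops created by the decomposition is not set correctly? The decomposition appears to reuse the operand's type instead of a type with dimensions 0 and 1 swapped: https://github.com/llvm/torch-mlir/blob/main/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L1444-L1445

For example, %5 transposes dimensions 0 and 1 of a [16,3,5,5] tensor, but its result type is still [16,3,5,5] rather than [3,16,5,5]: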

    %5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
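
To make the mismatch concrete, here is a small pure-Python sketch that computes the shape a transpose of two dimensions should produce and compares it with the result types printed in the IR above. transposed_shape is a hypothetical helper written for this illustration, not a torch-mlir function; the shapes are copied from the dump.

```python
# Hypothetical helper for illustration; not a torch-mlir API.
def transposed_shape(shape, dim0, dim1):
    """Return `shape` with dimensions dim0 and dim1 swapped.

    Negative dimensions count from the end, as in torch.transpose.
    """
    rank = len(shape)
    d0, d1 = dim0 % rank, dim1 % rank
    result = list(shape)
    result[d0], result[d1] = result[d1], result[d0]
    return result


# (SSA value, operand shape, result shape as printed in the IR)
transposes = [
    ("%5", [16, 3, 5, 5], [16, 3, 5, 5]),
    ("%7", [2, 16, 198, 198], [2, 16, 198, 198]),
    ("%8", [2, 3, 200, 200], [2, 3, 200, 200]),
]

for name, operand, printed in transposes:
    expected = transposed_shape(operand, 0, 1)
    status = "ok" if expected == printed else "mismatch"
    print(f"{name}: printed {printed}, expected {expected} -> {status}")
```

All three transposes listed here swap two dimensions of different sizes, so a result type identical to the operand type cannot be right for any of them.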

ramiro050 commented 1 year ago

Ah, good catch. Yeah, a new type should be created in the decomposition for the transposed weight.
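
A rough way to see what the corrected types would look like is to propagate shapes through the two convolutions of the decomposition. The sketch below is pure Python and assumes groups = 1 and the stride, padding and dilation values from the IR in this issue. conv2d_shape and swap01 are hypothetical helpers for illustration only; they say nothing about how the actual fix in DecomposeComplexOps.cpp is written.

```python
# Hypothetical helpers for illustration; not torch-mlir code.
def conv2d_shape(inp, weight, stride, padding, dilation):
    """Result shape of a 2-D convolution with groups == 1."""
    n, c_in, h, w = inp
    c_out, c_w, kh, kw = weight
    if c_in != c_w:
        raise ValueError(f"input has {c_in} channels, weight expects {c_w}")

    def out(size, k, s, p, d):
        return (size + 2 * p - d * (k - 1) - 1) // s + 1

    return [n, c_out,
            out(h, kh, stride[0], padding[0], dilation[0]),
            out(w, kw, stride[1], padding[1], dilation[1])]


def swap01(shape):
    """Shape after transposing dimensions 0 and 1."""
    return [shape[1], shape[0], *shape[2:]]


weight = [16, 3, 5, 5]         # %arg0
inp = [2, 3, 200, 200]         # %arg1
grad_out = [2, 16, 198, 198]   # %arg2

# Input gradient (%6): the flip keeps the shape, the transpose swaps
# dims 0 and 1, and the convolution uses padding [3, 3] (the %4 list).
weight_t = swap01(weight)
grad_input = conv2d_shape(grad_out, weight_t, [1, 1], [3, 3], [1, 1])

# Weight gradient (%9, %10): convolve the transposed input with the
# transposed grad_output, then transpose the result back.
grad_weight_t = conv2d_shape(swap01(inp), swap01(grad_out),
                             [1, 1], [1, 1], [1, 1])
grad_weight = swap01(grad_weight_t)

print(weight_t, grad_input, grad_weight_t, grad_weight)
```

With the swapped types, both final shapes agree with the function's result types. Note that the intermediate convolution %9 would then have shape [3,16,5,5], not the [16,3,5,5] printed in the dump, and that passing the untransposed [16,3,5,5] weight type to the first convolution gives a channel mismatch (16 input channels against 3 expected by the weight).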