Open ZihengJiang opened 1 year ago
FX Graph:
def forward(self, primals_1, primals_8, tangents_1): convolution_backward = torch.ops.aten.convolution_backward.default(tangents_1, primals_8, primals_1, [16], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]); tangents_1 = primals_8 = primals_1 = None getitem = convolution_backward[0] getitem_1 = convolution_backward[1] getitem_2 = convolution_backward[2]; convolution_backward = None return [getitem_1, getitem_2, None, None, None, None, None, getitem]
Converted torch dialect
module attributes {torch.debug_module_name = "GraphModule"} { func.func @forward(%arg0: !torch.vtensor<[16,3,5,5],f32>, %arg1: !torch.vtensor<[2,3,200,200],f32>, %arg2: !torch.vtensor<[2,16,198,198],f32>) -> (!torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32>) { %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 %none = torch.constant.none %false = torch.constant.bool false %true = torch.constant.bool true %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> %1 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> torch.runtime.assert %true, "unimplemented: only strides of 1 supported." torch.runtime.assert %true, "unimplemented: only strides of 1 supported." torch.runtime.assert %true, "unimplemented: only dilations of 1 supported." torch.runtime.assert %true, "unimplemented: only dilations of 1 supported." %2 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int> %3 = torch.aten.flip %arg0, %2 : !torch.vtensor<[16,3,5,5],f32>, !torch.list<int> -> !torch.vtensor<[16,3,5,5],f32> %4 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> %5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32> %6 = torch.aten.convolution %arg2, %5, %none, %0, %4, %0, %false, %1, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.vtensor<[16,3,5,5],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,3,200,200],f32> %7 = torch.aten.transpose.int %arg2, %int0, %int1 : !torch.vtensor<[2,16,198,198],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,16,198,198],f32> %8 = torch.aten.transpose.int %arg1, %int0, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.int, !torch.int -> 
!torch.vtensor<[2,3,200,200],f32> %9 = torch.aten.convolution %8, %7, %none, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[2,3,200,200],f32>, !torch.vtensor<[2,16,198,198],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[16,3,5,5],f32> %10 = torch.aten.transpose.int %9, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32> %11 = torch.prim.ListConstruct %int0, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> %12 = torch.aten.sum.dim_IntList %arg2, %11, %false, %none : !torch.vtensor<[2,16,198,198],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32> return %10, %12, %6 : !torch.vtensor<[16,3,5,5],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[2,3,200,200],f32> } }
It seems that the result type of the `TransposeInt` op has not been set correctly — it keeps the operand's shape `[16,3,5,5]` even though dimensions 0 and 1 are swapped? https://github.com/llvm/torch-mlir/blob/main/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L1444-L1445
%5 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[16,3,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[16,3,5,5],f32>
Ah, good catch. Yes — a new result type should be created in the decomposition for the transposed weight, rather than reusing the operand's type.
FX Graph:
Converted torch dialect
It seems that the result type of the `TransposeInt` op has not been set correctly? https://github.com/llvm/torch-mlir/blob/main/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L1444-L1445