ROCm / rocMLIR


Error: 'tensor.expand_shape' op expected dimension 1 of collapsed type to be static value of 320 #1618

Open · pfultz2 opened this issue 3 weeks ago

pfultz2 commented 3 weeks ago

Compilation fails with the following error:

Invalid MLIR created: Error: 'tensor.expand_shape' op expected dimension 1 of collapsed type to be static value of 320
Note: see current operation: %14 = "tensor.expand_shape"(%12) <{reassociation = [[0], [1, 2], [3], [4]], static_output_shape = array<i64: 2, 32, 10, 128, 128>}> : (tensor<2x128x128x320xf16>) -> tensor<2x32x10x128x128xf16>
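The check that fires here can be illustrated with a simplified Python sketch. For `tensor.expand_shape`, each source (collapsed) dimension must equal the product of the result dimensions in its reassociation group. The helper `check_expand_shape` is hypothetical: it only handles static shapes and models the one rule behind this message, not the full MLIR verifier.

```python
from math import prod


def check_expand_shape(src_shape, reassociation, result_shape):
    """Simplified static-shape check modeled on the tensor.expand_shape rule.

    Collapsed dimension i must equal the product of the result dimensions
    listed in reassociation[i]. Returns an error string, or None if valid.
    """
    if len(reassociation) != len(src_shape):
        return "reassociation must have one group per collapsed dimension"
    for i, group in enumerate(reassociation):
        expected = prod(result_shape[j] for j in group)
        if src_shape[i] != expected:
            return (f"expected dimension {i} of collapsed type "
                    f"to be static value of {expected}")
    return None


reassoc = [[0], [1, 2], [3], [4]]
out_shape = [2, 32, 10, 128, 128]
# Operand type from the failing op: tensor<2x128x128x320xf16>
print(check_expand_shape([2, 128, 128, 320], reassoc, out_shape))
# -> expected dimension 1 of collapsed type to be static value of 320
print(check_expand_shape([2, 320, 128, 128], reassoc, out_shape))  # -> None
```

Group `[1, 2]` multiplies to 32 * 10 = 320, but dimension 1 of the 2x128x128x320 operand is 128, so the op is rejected. With a 2x320x128x128 operand the same reassociation is consistent.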

The error is produced when compiling this MLIR program:

Problem: gfx942:sramecc+:xnack- 304     -t f16 -out_datatype f16 -transA false -transB false -g 1 -m 2 -n 1280 -k 320
module {
  func.func @mlir_convolution_reshape_add(%arg0: !migraphx.shaped<2x32x10x128x128xf16, 0x10x1x0x0>, %arg1: !migraphx.shaped<2x4x128x128xf16, 65536x1x512x4>, %arg2: !migraphx.shaped<320x4x3x3xf16, 36x1x12x4>) -> (!migraphx.shaped<2x320x128x128xf16, 5242880x1x40960x320>, !migraphx.shaped<2x32x10x128x128xf16, 5242880x163840x16384x128x1>) attributes {arch = "gfx942:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 304 : i64} {
    %0 = migraphx.convolution %arg1, %arg2 {dilation = [1, 1], group = 1 : i64, padding = [1, 1, 1, 1], padding_mode = 0 : i64, stride = [1, 1]} : <2x4x128x128xf16, 65536x1x512x4>, <320x4x3x3xf16, 36x1x12x4> -> <2x320x128x128xf16, 5242880x1x40960x320>
    %1 = migraphx.reshape %0 {dims = [2, 32, 10, 128, 128]} : <2x320x128x128xf16, 5242880x1x40960x320> -> <2x32x10x128x128xf16, 5242880x163840x16384x128x1>
    %2 = migraphx.add %1, %arg0 : <2x32x10x128x128xf16, 5242880x163840x16384x128x1>, <2x32x10x128x128xf16, 0x10x1x0x0> -> <2x32x10x128x128xf16, 5242880x163840x16384x128x1>
    return %0, %2 : !migraphx.shaped<2x320x128x128xf16, 5242880x1x40960x320>, !migraphx.shaped<2x32x10x128x128xf16, 5242880x163840x16384x128x1>
  }
}

The MLIR program was generated from this MIGraphX module:

y1.0 = @param:y1.0 -> half_type, {320, 4, 3, 3}, {36, 1, 12, 4}
y0.0 = @param:y0.0 -> half_type, {2, 4, 128, 128}, {65536, 1, 512, 4}
x2.0 = @param:x2.0 -> half_type, {2, 32, 10, 128, 128}, {0, 10, 1, 0, 0}
@3 = convolution[padding={1, 1, 1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](y0.0,y1.0) -> half_type, {2, 320, 128, 128}, {5242880, 1, 40960, 320}
@4 = reshape[dims={2, 32, 10, 128, 128}](@3) -> half_type, {2, 32, 10, 128, 128}, {5242880, 163840, 16384, 128, 1}
@5 = add(@4,x2.0) -> half_type, {2, 32, 10, 128, 128}, {5242880, 163840, 16384, 128, 1}
@6 = @return(@3,@5)
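The strides in this module determine how each tensor is laid out in memory. The Python sketch below (helper names `standard_strides` and `physical_shape` are hypothetical) orders dimensions by decreasing stride. It shows that the convolution output {2, 320, 128, 128} with strides {5242880, 1, 40960, 320} is stored in N, H, W, C order, i.e. as 2x128x128x320, the same shape as the operand in the error, while the reshape output uses the standard row-major strides for {2, 32, 10, 128, 128}.

```python
def standard_strides(lens):
    """Row-major (packed) strides for the given dimension lengths."""
    strides = [1] * len(lens)
    for d in range(len(lens) - 2, -1, -1):
        strides[d] = strides[d + 1] * lens[d + 1]
    return strides


def physical_shape(lens, strides):
    """Return (lengths from outermost to innermost in memory, dim order).

    Dimensions are sorted by decreasing stride; Python's sort is stable,
    so equal strides (e.g. broadcast zeros) keep their logical order.
    """
    order = sorted(range(len(lens)), key=lambda d: -strides[d])
    return [lens[d] for d in order], order


conv_lens = [2, 320, 128, 128]
conv_strides = [5242880, 1, 40960, 320]
print(physical_shape(conv_lens, conv_strides))      # ([2, 128, 128, 320], [0, 2, 3, 1])
print(conv_strides == standard_strides(conv_lens))  # False: not row-major
print(standard_strides([2, 32, 10, 128, 128]))      # [5242880, 163840, 16384, 128, 1]
```

Because the convolution output is not in standard layout, the channel dimension (320) is innermost in memory, whereas the reshape splits logical dimension 1 into 32x10.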

The backend output shows:

#map = affine_map<(d0, d1) -> (d0 + d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
#map2 = affine_map<(d0, d1) -> (0, d1)>
#map3 = affine_map<(d0, d1) -> (d0 * 320 + d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0 * 2 + d1, d2)>
#map5 = affine_map<(d0, d1, d2) -> (d0 * 320 + d1, d2)>
#map6 = affine_map<(d0, d1) -> (0, d0, d1)>
#map7 = affine_map<(d0, d1) -> (d0, d1)>
#map8 = affine_map<(d0) -> (d0 floordiv 1280, d0 mod 1280)>
#transform_map = #rock.transform_map<#map by [<Unmerge{1280, 1} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>] bounds = [1280, 1] -> [1280]>
#transform_map1 = #rock.transform_map<#map1 by [<PassThrough ["dim1", "dim0"] at [0, 1] -> ["dim1", "dim0"] at [1, 0]>] bounds = [1, 1280] -> [1280, 1]>
#transform_map2 = #rock.transform_map<#map2 by [<Broadcast{1} ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>] bounds = [2, 1280] -> [1, 1280]>
#transform_map3 = #rock.transform_map<#map3 by [<Unmerge{1280, 320} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>] bounds = [1280, 320] -> [409600]>
#transform_map4 = #rock.transform_map<#map1 by [<PassThrough ["dim1", "dim0"] at [0, 1] -> ["dim1", "dim0"] at [1, 0]>] bounds = [320, 1280] -> [1280, 320]>
#transform_map5 = #rock.transform_map<#map3 by [<Unmerge{1, 320} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>] bounds = [1, 320] -> [320]>
#transform_map6 = #rock.transform_map<#map2 by [<Broadcast{1} ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>] bounds = [2, 320] -> [1, 320]>
#transform_map7 = #rock.transform_map<#map4 by [<Unmerge{1, 2} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>] bounds = [1, 2, 320] -> [2, 320]>
#transform_map8 = #rock.transform_map<#map5 by [<Unmerge{1, 320} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>] bounds = [1, 320, 1280] -> [320, 1280]>
#transform_map9 = #rock.transform_map<#map6 by [<ConstDim{0, 1} [] at [] -> ["g"] at [0]>, <PassThrough ["d0", "d1"] at [0, 1] -> ["d0", "d1"] at [1, 2]>] bounds = [2, 320] -> [1, 2, 320]>
#transform_map10 = #rock.transform_map<#map6 by [<Merge{1, 1280} ["gd1"] at [1] -> ["g", "d1"] at [0, 2]>, <PassThrough ["d0"] at [0] -> ["d0"] at [1]>] bounds = [320, 1280] -> [1, 320, 1280]>
#transform_map11 = #rock.transform_map<#map6 by [<Merge{1, 1280} ["gd1"] at [1] -> ["g", "d1"] at [0, 2]>, <PassThrough ["d0"] at [0] -> ["d0"] at [1]>] bounds = [2, 1280] -> [1, 2, 1280]>
#transform_map12 = #rock.transform_map<#map6 by [<Merge{1, 2} ["dim0"] at [0] -> ["col0", "col1"] at [0, 1]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [2]>] bounds = [2, 1280] -> [1, 2, 1280]>
#transform_map13 = #rock.transform_map<#map8 by [<Merge{2, 1280} ["dim0"] at [0] -> ["col0", "col1"] at [0, 1]>] bounds = [2560] -> [2, 1280]>
module {
  func.func @mlir_dot_add_sigmoid_mul(%arg0: memref<1280xf16>, %arg1: memref<320xf16>, %arg2: memref<409600xf16>, %arg3: memref<2560xf16>) attributes {arch = "gfx942:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 304 : i64} {
    %cst = arith.constant 1.000000e+00 : f16
    %0 = rock.transform %arg0 by #transform_map : memref<1280xf16> to memref<1280x1xf16>
    %1 = rock.transform %0 by #transform_map1 : memref<1280x1xf16> to memref<1x1280xf16>
    %2 = rock.transform %1 by #transform_map2 : memref<1x1280xf16> to memref<2x1280xf16>
    %3 = rock.transform %arg2 by #transform_map3 : memref<409600xf16> to memref<1280x320xf16>
    %4 = rock.transform %3 by #transform_map4 : memref<1280x320xf16> to memref<320x1280xf16>
    %5 = rock.transform %arg1 by #transform_map5 : memref<320xf16> to memref<1x320xf16>
    %6 = rock.transform %5 by #transform_map6 : memref<1x320xf16> to memref<2x320xf16>
    %7 = rock.transform %6 by #transform_map7 : memref<2x320xf16> to memref<1x2x320xf16>
    %8 = rock.transform %4 by #transform_map8 : memref<320x1280xf16> to memref<1x320x1280xf16>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x1280xf16>
    %9 = rock.transform %7 by #transform_map9 : memref<1x2x320xf16> to memref<2x320xf16>
    %10 = rock.transform %8 by #transform_map10 : memref<1x320x1280xf16> to memref<320x1280xf16>
    %11 = rock.transform %alloc by #transform_map11 : memref<1x2x1280xf16> to memref<2x1280xf16>
    rock.gemm %11 = %9 * %10 features =  mfma|dot|atomic_add storeMethod =  set {arch = "gfx942:sramecc+:xnack-", numCU = 304 : i32} : memref<2x1280xf16> = memref<2x320xf16> * memref<320x1280xf16>
    %12 = rock.transform %alloc by #transform_map12 : memref<1x2x1280xf16> to memref<2x1280xf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x1280xf16>
    linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%12, %2 : memref<2x1280xf16>, memref<2x1280xf16>) outs(%alloc_0 : memref<2x1280xf16>) {
    ^bb0(%in: f16, %in_1: f16, %out: f16):
      %14 = arith.addf %in, %in_1 : f16
      %15 = arith.negf %14 : f16
      %16 = math.exp %15 : f16
      %17 = arith.addf %16, %cst : f16
      %18 = arith.divf %cst, %17 : f16
      %19 = arith.mulf %14, %18 : f16
      linalg.yield %19 : f16
    }
    %13 = rock.transform %alloc_0 by #transform_map13 : memref<2x1280xf16> to memref<2560xf16>
    memref.copy %13, %arg3 : memref<2560xf16> to memref<2560xf16>
    return
  }
}
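For checking the backend module, here is a minimal Python reference sketch of what `@mlir_dot_add_sigmoid_mul` computes, obtained by composing its `rock.transform` maps by hand. The `rock.gemm` multiplies 2x320 by 320x1280, matching `-m 2 -n 1280 -k 320` on the Problem line, and the `linalg.generic` body computes `s * (1 / (1 + exp(-s)))`. The function name `silu_gemm_reference` is hypothetical, it uses Python floats instead of f16, and the index mapping is a hand reading of the transform maps rather than an authoritative description of rocMLIR.

```python
import math


def silu_gemm_reference(arg0, arg1, arg2, m=2, k=320, n=1280):
    """Hypothetical reference for @mlir_dot_add_sigmoid_mul (float, not f16).

    arg0: n-element bias, broadcast to [m, n]
    arg1: k-element row, broadcast to A[m, k] (every row identical)
    arg2: n*k elements viewed as [n, k] (Unmerge{n, k}), transposed to B[k, n]
    returns arg3: m*n elements, row-major [m, n] (Merge{m, n} on the way out)
    """
    assert len(arg0) == n and len(arg1) == k and len(arg2) == n * k
    out = [0.0] * (m * n)
    for i in range(m):
        for j in range(n):
            # rock.gemm: C[i][j] = sum over kk of A[i][kk] * B[kk][j],
            # where A[i][kk] = arg1[kk] and B[kk][j] = arg2[j * k + kk].
            acc = sum(arg1[kk] * arg2[j * k + kk] for kk in range(k))
            s = acc + arg0[j]  # arith.addf with the broadcast bias
            # negf, exp, addf 1, divf, mulf: s * sigmoid(s)
            out[i * n + j] = s * (1.0 / (1.0 + math.exp(-s)))
    return out


# Small example with k=3, n=4 instead of the real 320 and 1280.
res = silu_gemm_reference([0.0] * 4, [1.0, 2.0, 3.0],
                          [float(x) for x in range(12)], m=2, k=3, n=4)
print(res[:4] == res[4:])  # True: both output rows are identical
```

Since `arg1` is broadcast along the M dimension, both rows of the 2x1280 result are the same in this reading.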
krzysz00 commented 3 weeks ago

@pfultz2 I can't reproduce the issue - is this the right input?