tensorflow / mlir

"Multi-Level Intermediate Representation" Compiler Infrastructure

Bad affine loop fusion #222

Closed. liwchang closed this issue 5 years ago.

liwchang commented 5 years ago

Build 11/04, Windows. A trivial matrix multiplication followed by an add can go wrong in affine loop fusion. There are two test cases, mul_add_0 and mul_add_1: mul_add_0 generates completely wrong code, while mul_add_1 isn't fused at all, though it should be.

Using mlir-opt.exe %s -affine-loop-fusion -split-input-file

func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) {
  %cst = constant 0.000000e+00 : f32
  %0 = alloc() : memref<3x3xf32>
  // Zero-initialize the 3x3 temporary %0.
  affine.for %arg4 = 0 to 3 {
    affine.for %arg5 = 0 to 3 {
      affine.store %cst, %0[%arg4, %arg5] : memref<3x3xf32>
    }
  }
  // Matrix multiplication: %0 += %arg0 * %arg1.
  affine.for %arg4 = 0 to 3 {
    affine.for %arg5 = 0 to 3 {
      affine.for %arg6 = 0 to 4 {
        %1 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32>
        %2 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32>
        %3 = mulf %2, %1 : f32
        %4 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
        %5 = addf %4, %3 : f32
        affine.store %5, %0[%arg4, %arg5] : memref<3x3xf32>
      }
    }
  }
  // Elementwise add: %arg3 = %0 + %arg2.
  affine.for %arg4 = 0 to 3 {
    affine.for %arg5 = 0 to 3 {
      %6 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32>
      %7 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
      %8 = addf %7, %6 : f32
      affine.store %8, %arg3[%arg4, %arg5] : memref<3x3xf32>
    }
  }
  return
}

// -----

func @mul_add_1(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) {
  %cst = constant 0.000000e+00 : f32
  %0 = alloc() : memref<3x3xf32>
  // Zero-initialization and matrix multiplication already share one nest.
  affine.for %arg4 = 0 to 3 {
    affine.for %arg5 = 0 to 3 {
      affine.store %cst, %0[%arg4, %arg5] : memref<3x3xf32>
      affine.for %arg6 = 0 to 4 {
        %1 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32>
        %2 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32>
        %3 = mulf %2, %1 : f32
        %4 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
        %5 = addf %4, %3 : f32
        affine.store %5, %0[%arg4, %arg5] : memref<3x3xf32>
      }
    }
  }
  // Elementwise add: %arg3 = %0 + %arg2.
  affine.for %arg4 = 0 to 3 {
    affine.for %arg5 = 0 to 3 {
      %6 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32>
      %7 = affine.load %0[%arg4, %arg5] : memref<3x3xf32>
      %8 = addf %7, %6 : f32
      affine.store %8, %arg3[%arg4, %arg5] : memref<3x3xf32>
    }
  }
  return
}
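As a plain reference for what both functions compute (%arg3 = %arg0 × %arg1 + %arg2), here is a minimal Python sketch that mirrors their loop nests one-to-one. It is an illustration only, not MLIR: the Python functions reuse the test-case names for readability, and the input values A, B and C are arbitrary.

```python
# Plain-Python model of the two test functions, mirroring their loop nests.
# a is 3x4 (%arg0), b is 4x3 (%arg1), c is 3x3 (%arg2); the result is %arg3.


def mul_add_0(a, b, c):
    # First nest: zero-initialize the 3x3 temporary (%0).
    tmp = [[0.0] * 3 for _ in range(3)]
    # Second nest: tmp[i][j] += a[i][k] * b[k][j].
    for i in range(3):
        for j in range(3):
            for k in range(4):
                tmp[i][j] = tmp[i][j] + a[i][k] * b[k][j]
    # Third nest: out = tmp + c (written to %arg3).
    out = [[0.0] * 3 for _ in range(3)]
    for i in range(3):
        for j in range(3):
            out[i][j] = tmp[i][j] + c[i][j]
    return out


def mul_add_1(a, b, c):
    # Same computation, with the initialization already merged into the
    # matmul nest as in the second test case.
    tmp = [[None] * 3 for _ in range(3)]
    for i in range(3):
        for j in range(3):
            tmp[i][j] = 0.0
            for k in range(4):
                tmp[i][j] = tmp[i][j] + a[i][k] * b[k][j]
    out = [[0.0] * 3 for _ in range(3)]
    for i in range(3):
        for j in range(3):
            out[i][j] = tmp[i][j] + c[i][j]
    return out


# Arbitrary small inputs with exactly representable values.
A = [[float(i * 4 + k) for k in range(4)] for i in range(3)]
B = [[float(k - j) for j in range(3)] for k in range(4)]
C = [[1.0] * 3 for _ in range(3)]
print(mul_add_0(A, B, C) == mul_add_1(A, B, C))  # True
```

Any correct fusion of these nests must produce the same values. Because both models perform the same floating-point operations in the same order, an exact == comparison is safe here.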
ftynse commented 5 years ago

It would be great if you could provide more details on what is "completely wrong".

With 6bb76f3, I got the following output for the first example:

#map0 = (d0, d1) -> (d0, d1)
#map1 = () -> (0)
#map2 = () -> (3)
#map3 = () -> (0, 0)
#map4 = () -> (4)

module {
  func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) {
    %0 = alloc() : memref<1x1xf32>
    %cst = constant 0.000000e+00 : f32
    %1 = alloc() : memref<3x3xf32>
    affine.for %arg4 = 0 to 3 {
      affine.for %arg5 = 0 to 3 {
        affine.store %cst, %1[%arg4, %arg5] : memref<3x3xf32>
      }
    }
    affine.for %arg4 = 0 to 3 {
      affine.for %arg5 = 0 to 3 {
        affine.for %arg6 = 0 to 4 {
          %5 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32>
          %6 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32>
          %7 = mulf %6, %5 : f32
          %8 = affine.load %0[0, 0] : memref<1x1xf32>
          %9 = addf %8, %7 : f32
          affine.store %9, %0[0, 0] : memref<1x1xf32>
        }
        %2 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32>
        %3 = affine.load %0[0, 0] : memref<1x1xf32>
        %4 = addf %3, %2 : f32
        affine.store %4, %arg3[%arg4, %arg5] : memref<3x3xf32>
      }
    }
    return
  }
}

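This indeed does not look correct. In particular, the 1x1 buffer %0 introduced by fusion is read (%8 = affine.load %0[0, 0]) before anything is ever written to it: the zero-initialization nest still stores into the old 3x3 buffer %1, which is now dead. I suppose %0 is a temporary that the fusion pass intended to use but did not update properly. @bondhugula, @andydavis1 can you take a look?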
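The failure mode can be reproduced outside MLIR with a small Python model of the posted output. This is a sketch under stated assumptions: TracedBuffer and buggy_fused are hypothetical helpers, NaN stands in for the undefined contents of a fresh alloc, and the inputs are arbitrary. The model logs loads of cells that were never stored to, and also shows that a lucky zero-filled buffer would not help, because the 1x1 accumulator is never reset between (%arg4, %arg5) iterations.

```python
import math


class TracedBuffer:
    """Hypothetical memref stand-in: cells start with `initial` (NaN models
    'undefined'), and loads of never-stored cells are recorded."""

    def __init__(self, rows, cols, initial=math.nan):
        self.data = [[initial] * cols for _ in range(rows)]
        self.written = [[False] * cols for _ in range(rows)]
        self.uninit_reads = []

    def load(self, i, j):
        if not self.written[i][j]:
            self.uninit_reads.append((i, j))
        return self.data[i][j]

    def store(self, value, i, j):
        self.data[i][j] = value
        self.written[i][j] = True


def buggy_fused(a, b, c, initial=math.nan):
    """Mirrors the fused output posted above for mul_add_0."""
    acc = TracedBuffer(1, 1, initial)   # %0 = alloc() : memref<1x1xf32>
    dead = TracedBuffer(3, 3, initial)  # %1 = alloc() : memref<3x3xf32>
    for i in range(3):
        for j in range(3):
            dead.store(0.0, i, j)  # zero-init writes %1, which is never read
    out = [[0.0] * 3 for _ in range(3)]
    for i in range(3):
        for j in range(3):
            for k in range(4):  # acc is not reset before this reduction
                acc.store(acc.load(0, 0) + a[i][k] * b[k][j], 0, 0)
            out[i][j] = acc.load(0, 0) + c[i][j]
    return out, acc.uninit_reads


A = [[float(i * 4 + k) for k in range(4)] for i in range(3)]
B = [[float(k - j) for j in range(3)] for k in range(4)]
C = [[1.0] * 3 for _ in range(3)]

out, reads = buggy_fused(A, B, C)
print(reads)  # [(0, 0)]: the very first load of %0 precedes any store
print(all(math.isnan(x) for row in out for x in row))  # True

# Even if the fresh buffer happened to hold 0.0, partial sums leak into
# every later element because the accumulator is never re-zeroed.
lucky, _ = buggy_fused(A, B, C, initial=0.0)
print(lucky[0][0], lucky[0][1])  # 15.0 23.0 (correct values: 15.0 and 9.0)
```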

Whether to fuse is a heuristic-driven decision, so I would not say whether something should or should not be fused. In general, no fusion is better than buggy fusion.

andydavis1 commented 5 years ago

Thanks for pointing this out. I'm investigating...

bondhugula commented 5 years ago

Thanks for the report. The output for the first case should have been:

#map0 = (d0, d1) -> (d0, d1)
#map1 = () -> (0)
#map2 = () -> (3)
#map3 = () -> (0, 0)
#map4 = () -> (4)

module {
  func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x3xf32>, %arg3: memref<3x3xf32>) {
    %0 = alloc() : memref<1x1xf32>
    %cst = constant 0.000000e+00 : f32
    affine.for %arg4 = 0 to 3 {
      affine.for %arg5 = 0 to 3 {
        affine.store %cst, %0[0, 0] : memref<1x1xf32>
        affine.for %arg6 = 0 to 4 {
          %5 = affine.load %arg1[%arg6, %arg5] : memref<4x3xf32>
          %6 = affine.load %arg0[%arg4, %arg6] : memref<3x4xf32>
          %7 = mulf %6, %5 : f32
          %8 = affine.load %0[0, 0] : memref<1x1xf32>
          %9 = addf %8, %7 : f32
          affine.store %9, %0[0, 0] : memref<1x1xf32>
        }
        %2 = affine.load %arg2[%arg4, %arg5] : memref<3x3xf32>
        %3 = affine.load %0[0, 0] : memref<1x1xf32>
        %4 = addf %3, %2 : f32
        affine.store %4, %arg3[%arg4, %arg5] : memref<3x3xf32>
      }
    }
    return
  }
}

This would eliminate the intermediate 3x3 matrix entirely, turning it into a single-element (1x1) memref that is re-initialized to zero for each (%arg4, %arg5) iteration.
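To check that the expected output preserves semantics, the following Python sketch mirrors it next to the original unfused nests. As before, this is an illustration with arbitrary inputs, not the pass's implementation, and the function names are hypothetical. The only per-element state is the one-cell accumulator, which is re-zeroed before each k-reduction, matching the affine.store %cst, %0[0, 0] at the top of the %arg5 loop.

```python
def fused_mul_add_0(a, b, c):
    """Mirrors the expected output: a 1x1 accumulator (%0) that is
    re-zeroed for every (i, j) before the k-reduction."""
    acc = [[0.0]]  # %0 = alloc() : memref<1x1xf32>
    out = [[0.0] * 3 for _ in range(3)]
    for i in range(3):
        for j in range(3):
            acc[0][0] = 0.0  # affine.store %cst, %0[0, 0]
            for k in range(4):
                acc[0][0] = acc[0][0] + a[i][k] * b[k][j]
            out[i][j] = acc[0][0] + c[i][j]
    return out


def unfused_mul_add_0(a, b, c):
    """Reference: the original nests with a full 3x3 temporary."""
    tmp = [[0.0] * 3 for _ in range(3)]
    for i in range(3):
        for j in range(3):
            for k in range(4):
                tmp[i][j] = tmp[i][j] + a[i][k] * b[k][j]
    return [[tmp[i][j] + c[i][j] for j in range(3)] for i in range(3)]


A = [[float(i * 4 + k) for k in range(4)] for i in range(3)]
B = [[float(k - j) for j in range(3)] for k in range(4)]
C = [[1.0] * 3 for _ in range(3)]
print(fused_mul_add_0(A, B, C) == unfused_mul_add_0(A, B, C))  # True
```

The temporary shrinks from 9 floats to 1, and the re-zeroing store is exactly what the buggy output was missing.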

andydavis1 commented 5 years ago

I have a fix on the way.

andydavis1 commented 5 years ago

OK, I submitted a fix. I'll close this bug for now. Please let me know if you see any more issues.