iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[codegen] softmax nans #17670

Open dan-garvey opened 1 week ago

dan-garvey commented 1 week ago

What happened?

func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178x1178xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<2x24x1178x1178xf32>
  %1 = linalg.softmax dimension(3) ins(%arg0 : tensor<2x24x1178x1178xf32>) outs(%0 : tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178x1178xf32>
  return %1 : tensor<2x24x1178x1178xf32>
}

Compile command:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-target-cpu-features=host 42.mlir -o 42.vmfb

Input npy: https://sharkblobs.blob.core.windows.net/dan/42_inputs.npy

Output npy (for comparison): https://sharkblobs.blob.core.windows.net/dan/42_out.npy

iree-run-module --module=42.vmfb --function=softmax --input=@42_inputs.npy --output=@42_out_repro.npy
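For comparison against the module output, a NumPy reference of softmax over dimension 3 can be sketched as below. This is only an illustrative reference, not IREE's lowering; the function name softmax_reference and the small stand-in shape are assumptions made for the example.

```python
import numpy as np


def softmax_reference(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis` (NumPy reference, not IREE code)."""
    # Subtract the per-row maximum so exp() never sees large positive values.
    row_max = np.max(x, axis=axis, keepdims=True)
    exps = np.exp(x - row_max)
    return exps / np.sum(exps, axis=axis, keepdims=True)


if __name__ == "__main__":
    # Small stand-in for the 2x24x1178x1178 input from the report.
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 3, 4, 5)).astype(np.float32)
    y = softmax_reference(x, axis=3)
    # Every row along the softmax dimension should sum to roughly 1.
    print(np.allclose(y.sum(axis=3), 1.0, atol=1e-5))
```

The row maximum is computed first and subtracted before exponentiation, so a wrong maximum can propagate into the final result.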

dan-garvey commented 1 week ago

@hanhanW identified that the rows [0, 18, 63, *] and [0, 18, 1389, *] are NaNs
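One way to locate such rows in a dumped output is sketched below. The helper name nan_rows is hypothetical, and the demo uses a small synthetic array rather than the real dump.

```python
import numpy as np


def nan_rows(out: np.ndarray) -> np.ndarray:
    """Return the leading indices of every last-axis row that contains a NaN."""
    # Collapse the softmax dimension: True where any element of the row is NaN.
    has_nan = np.isnan(out).any(axis=-1)
    return np.argwhere(has_nan)


if __name__ == "__main__":
    # Synthetic stand-in for the real output tensor.
    out = np.zeros((2, 3, 4, 5), dtype=np.float32)
    out[0, 1, 2, :] = np.nan
    for idx in nan_rows(out):
        print("NaN row at", tuple(int(i) for i in idx))
```

For the real data, the array would come from np.load("42_out_repro.npy"), the file written by the run command above.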

hanhanW commented 1 week ago

@pashu123 please help with further triaging. We dumped the inputs and outputs and verified that there are NaNs.

pashu123 commented 1 week ago

On further debugging, the problem is in the max calculation. A smaller repro:

func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178xf32> {
  %4 = tensor.empty() : tensor<2x24x1178xf32>
  %cst = arith.constant -3.40282347E+38 : f32
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x24x1178xf32>) -> tensor<2x24x1178xf32>

  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<2x24x1178x1178xf32>) outs(%5 : tensor<2x24x1178xf32>) {
  ^bb0(%in: f32, %out: f32):
    %10 = arith.maximumf %in, %out : f32
    linalg.yield %10 : f32
  } -> tensor<2x24x1178xf32>

  return %6 : tensor<2x24x1178xf32>
}

https://gist.github.com/pashu123/83ca1f519aa39f1ce7a035122bbb7e54 (compile and run commands are the same as above)
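The reduction in this repro can be mirrored in NumPy to produce a golden value. The sketch below is an assumption about how such a reference might look (the name golden_row_max is made up here): it starts from the same lowest finite f32 value as %cst and folds the last dimension with an elementwise maximum.

```python
import numpy as np


def golden_row_max(x: np.ndarray) -> np.ndarray:
    """Max-reduce the last dimension, mirroring the linalg.generic in the repro."""
    # Same init value as %cst: the lowest finite f32 (-3.40282347E+38).
    init = np.finfo(np.float32).min
    acc = np.full(x.shape[:-1], init, dtype=np.float32)
    # Fold the reduction dimension (d3) into the accumulator one slice at a time.
    for k in range(x.shape[-1]):
        acc = np.maximum(x[..., k], acc)
    return acc


if __name__ == "__main__":
    x = np.array([[-1.5, -2.0, -0.25], [-19.0, -25.0, -6.0]], dtype=np.float32)
    print(golden_row_max(x))
```

For rows whose values are all negative, the expected maximum is itself negative, so an output of -0.0 for such a row cannot be correct.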

I have created a Python script to debug: https://gist.github.com/pashu123/898636a138e41e1db2443acd1248d6d4

The output of the Python script:

Mismatch at index (np.int64(0), np.int64(2), np.int64(1)): golden=-1.6139899492263794, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(9)): golden=-1.1718499660491943, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(10)): golden=-1.594499945640564, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(11)): golden=-1.9860199689865112, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(18)): golden=-1.1132500171661377, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(19)): golden=-2.1459200382232666, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(20)): golden=-1.3908900022506714, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(21)): golden=-1.2039200067520142, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(23)): golden=-3.720489978790283, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(24)): golden=-3.0760700702667236, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(25)): golden=-3.9601500034332275, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(26)): golden=-2.8110198974609375, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(27)): golden=-1.5647300481796265, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(29)): golden=-1.171970009803772, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(30)): golden=-2.9511098861694336, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(31)): golden=-1.1302900314331055, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(32)): golden=-3.8724400997161865, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(33)): golden=-1.8330700397491455, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(34)): golden=-1.1605299711227417, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(35)): golden=-5.191100120544434, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(36)): golden=-3.998159885406494, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(48)): golden=-2.524359941482544, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(54)): golden=-1.4726300239562988, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(55)): golden=-6.302299976348877, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(56)): golden=-1.0678900480270386, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(62)): golden=-3.644969940185547, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(63)): golden=-4.302030086517334, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(65)): golden=-1.2450499534606934, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(66)): golden=-2.546420097351074, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(70)): golden=-1.760390043258667, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(80)): golden=-1.1018799543380737, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(84)): golden=-2.6196000576019287, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(85)): golden=-1.4363700151443481, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(92)): golden=-1.8270699977874756, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(93)): golden=-5.119679927825928, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(94)): golden=-3.4443399906158447, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(95)): golden=-1.8535699844360352, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(98)): golden=-1.696810007095337, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(99)): golden=-2.281130075454712, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(100)): golden=-2.694159984588623, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(101)): golden=-3.200939893722534, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(102)): golden=-4.250319957733154, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(103)): golden=-2.6362600326538086, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(108)): golden=-1.3708399534225464, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(115)): golden=-1.9866199493408203, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(118)): golden=-2.3564600944519043, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(122)): golden=-4.689330101013184, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(123)): golden=-3.47625994682312, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(124)): golden=-2.152790069580078, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(125)): golden=-1.2989599704742432, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(127)): golden=-5.363550186157227, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(128)): golden=-4.256410121917725, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(130)): golden=-2.7768800258636475, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(134)): golden=-1.7649099826812744, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(135)): golden=-3.982069969177246, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(136)): golden=-6.1743998527526855, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(137)): golden=-6.286499977111816, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(138)): golden=-2.8284900188446045, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(139)): golden=-5.993460178375244, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(140)): golden=-3.436150074005127, iree=-0.0
Mismatch at index (np.int64(0), np.int64(2), np.int64(144)): golden=-2.254849910736084, iree=-0.0
Mismatch at index (np.int64(0), np.int64(4), np.int64(58)): golden=-1.4180999994277954, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(1)): golden=-19.961200714111328, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(2)): golden=-25.51609992980957, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(3)): golden=-6.272930145263672, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(4)): golden=-9.712470054626465, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(6)): golden=-10.295499801635742, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(7)): golden=-16.54210090637207, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(8)): golden=-39.7671012878418, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(9)): golden=-22.47920036315918, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(10)): golden=-23.77739906311035, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(11)): golden=-40.10390090942383, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(12)): golden=-10.307700157165527, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(13)): golden=-8.724579811096191, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(14)): golden=-3.1235899925231934, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(15)): golden=-15.26159954071045, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(16)): golden=-8.746410369873047, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(17)): golden=-9.033740043640137, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(18)): golden=-36.70589828491211, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(19)): golden=-41.16350173950195, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(20)): golden=-38.764198303222656, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(21)): golden=-20.7450008392334, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(22)): golden=-14.468999862670898, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(23)): golden=-19.56329917907715, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(24)): golden=-17.083499908447266, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(25)): golden=-20.79840087890625, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(26)): golden=-11.901700019836426, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(27)): golden=-21.383699417114258, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(28)): golden=-17.52400016784668, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(29)): golden=-16.292200088500977, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(30)): golden=-15.337599754333496, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(31)): golden=-14.481499671936035, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(32)): golden=-21.077600479125977, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(33)): golden=-26.247299194335938, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(34)): golden=-31.76959991455078, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(35)): golden=-27.92840003967285, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(36)): golden=-10.960000038146973, iree=-0.0
Mismatch at index (np.int64(0), np.int64(6), np.int64(37)): golden=-10.226499557495117, iree=-0.0

It looks like IREE's output gets stuck at -0.0.
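A comparison of this kind can be sketched as follows. This is not the linked gist, only a guess at its general shape; the function name find_mismatches and the tolerances are assumptions.

```python
import numpy as np


def find_mismatches(golden, actual, rtol=1e-5, atol=1e-6):
    """List (index, golden value, actual value) where the two arrays disagree."""
    # equal_nan=True treats a NaN in both arrays at the same index as a match.
    bad = ~np.isclose(golden, actual, rtol=rtol, atol=atol, equal_nan=True)
    result = []
    for idx in np.argwhere(bad):
        key = tuple(int(i) for i in idx)
        result.append((key, float(golden[key]), float(actual[key])))
    return result


if __name__ == "__main__":
    # Tiny synthetic example; the real arrays would be loaded from .npy files.
    golden = np.array([[-1.5, -2.0], [-3.0, 0.5]], dtype=np.float32)
    iree = np.array([[-0.0, -2.0], [-0.0, 0.5]], dtype=np.float32)
    for index, g, a in find_mismatches(golden, iree):
        print(f"Mismatch at index {index}: golden={g}, iree={a}")
```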

Using arith.maxnumf (https://mlir.llvm.org/docs/Dialects/ArithOps/#arithmaxnumf-arithmaxnumfop) instead of arith.maximumf, i.e.,

func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178xf32> {
  %4 = tensor.empty() : tensor<2x24x1178xf32>
  %cst = arith.constant -3.40282347E+38 : f32
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x24x1178xf32>) -> tensor<2x24x1178xf32>

  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<2x24x1178x1178xf32>) outs(%5 : tensor<2x24x1178xf32>) {
  ^bb0(%in: f32, %out: f32):
    %10 = arith.maxnumf %in, %out : f32
    linalg.yield %10 : f32
  } -> tensor<2x24x1178xf32>

  return %6 : tensor<2x24x1178xf32>
}

solves the problem. Meanwhile, I am reading the documentation; it is not yet clear to me why this happens 😆.

Cherry-pick: https://github.com/pashu123/llvm-project/tree/fyi_soft (verified)

hanhanW commented 1 week ago

Some related reading: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671

I recall when we split min (max) into minimum/minnum (maximum/maxnum). We may have missed it in softmax because it was not on my radar.
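The documented difference between the two variants is their NaN handling: arith.maximumf returns NaN if either operand is NaN, while arith.maxnumf returns the other operand. NumPy's np.maximum and np.fmax behave analogously, so the split can be illustrated with the sketch below. It only demonstrates those semantics and is not an explanation of the -0.0 results reported above.

```python
import numpy as np

a = np.array([1.0, np.nan, -2.0], dtype=np.float32)
b = np.array([np.nan, 3.0, -5.0], dtype=np.float32)

# np.maximum propagates NaN, analogous to arith.maximumf.
propagating = np.maximum(a, b)
# np.fmax returns the non-NaN operand, analogous to arith.maxnumf.
ignoring = np.fmax(a, b)

print("maximum:", propagating)
print("fmax:   ", ignoring)
```

The first two lanes come out as NaN under np.maximum and as 1.0 and 3.0 under np.fmax; the third lane is -2.0 in both.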