Ran it with ASAN and didn't get any errors.
Error started happening after https://github.com/openxla/iree/commit/a3d75cc3061e57ebdaf1ae6c381e1db3ca4cc0de
Oof. Needless to say, we need to track this down. A couple of things to try:
- -iree-consteval-jit-debug: dumps a bunch of information about what is being evaled (or set the env var IREE_COMPILER_DEBUG_CONSTEVAL=1).
- -iree-opt-const-eval=false: keeps the hoisting but disables constant evaluation. May tell us whether the constexpr hoisting or the actual JIT evaluation is causing the issue.

So the accuracy errors still occur with --iree-opt-const-eval=false. @hanhanW also suggested trying with --iree-consteval-jit-target-backend=vmvx, and accuracy errors still occur.
Here is the output with --iree-consteval-jit-debug: const_eval.txt
Thank you - this is helpful as it eliminates a rather large variable: the consteval jitter is a big, complicated integration, and if that is not the problem, then it must be the hoisting heuristic itself that is doing something it shouldn't and generating a subtly incorrect program. That is a single pass, and it has debug options we can use to see what it is doing.
There are a couple of next steps that can be tried:
- --mlir-print-ir-before=iree-util-hoist-into-globals
If I were doing this, I would first make a detailed visual inspection of the IR changes that were made and see if I could spot anything out of the ordinary. I would also pass:
- --debug --debug-only=iree-util-hoist-into-globals
to get a bit more information (a hypothetical invocation combining these flags is sketched after this list). Basically, I don't have a strong theory, but those are the debugging steps I would go down. I will likely have time to get to this on Monday and would welcome any work to narrow it down if you have the time before then.
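For reference, a hypothetical iree-compile invocation combining these debugging flags might look like the following (the model and output paths are placeholders taken from the repro steps later in this thread, and --debug/--debug-only require an assertions-enabled compiler build):
iree-compile \
  --iree-hal-target-backends=llvm-cpu \
  --mlir-print-ir-before=iree-util-hoist-into-globals \
  --debug --debug-only=iree-util-hoist-into-globals \
  ~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/stablehlo.mlirbc \
  -o /tmp/a.vmfb 2> hoist_debug.log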
- Keep an eye out for subtle dependency issues that may have changed when transforming, especially for nested ops as those are really tricky to deal with. (Mahesh is independently working on a patch to restrict ConstEval to not work on nested ops, and it might be worth checking with him to see what prompted him to do that and whether there might be a relation)
That PR already landed: https://github.com/openxla/iree/commit/eeb6e80d4ade331a4271f92ea776a7afcbdc9ecb. I did this because I was trying to handle some cases where we have pre-formed flow.dispatch.region operations and const-eval was hoisting out operations from within pre-formed dispatches. So I just disabled that.
Thanks for the steps, I'll give it a crack.
Thank you for having a look. If you make progress but can't see the issue, feel free to share artifacts.
I compared the mlir output from running -iree-util-hoist-into-globals between a working model (batch 32) and a failing one (batch 64). For the most part, the IRs are similar, with 32 replaced with 64 and some different variable names.
There is a chunk of code that differs repeatedly, which looks to stem from %cst_57 = arith.constant dense<7.680000e+02> : tensor<64x512xf32> (batch 64) vs %cst_57 = arith.constant dense<7.680000e+02> : tensor<2x32x512xf32> (batch 32).
Batch 32 (working):
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map6 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map7 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map8 = affine_map<(d0, d1, d2) -> (d1)>
#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
%cst_56 = arith.constant dense<0.000000e+00> : tensor<32x512xf32>
%cst_57 = arith.constant dense<7.680000e+02> : tensor<2x32x512xf32>
%cst_62 = arith.constant 0.000000e+00 : f32
%23 = tensor.empty() : tensor<1x32x512x768xf32>
%24 = linalg.generic {indexing_maps = [#map6, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%21 : tensor<32x512x768xf32>) outs(%23 : tensor<1x32x512x768xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x512x768xf32>
%25 = linalg.generic {indexing_maps = [#map6, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%22 : tensor<32x512x768xf32>) outs(%23 : tensor<1x32x512x768xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x512x768xf32>
%26 = tensor.empty() : tensor<2x32x512x768xf32>
%inserted_slice = tensor.insert_slice %24 into %26[0, 0, 0, 0] [1, 32, 512, 768] [1, 1, 1, 1] : tensor<1x32x512x768xf32> into tensor<2x32x512x768xf32>
%inserted_slice_64 = tensor.insert_slice %25 into %inserted_slice[1, 0, 0, 0] [1, 32, 512, 768] [1, 1, 1, 1] : tensor<1x32x512x768xf32> into tensor<2x32x512x768xf32>
%27 = tensor.empty() : tensor<2x32x512xf32>
%28 = linalg.fill ins(%cst_62 : f32) outs(%27 : tensor<2x32x512xf32>) -> tensor<2x32x512xf32>
%29 = linalg.generic {indexing_maps = [#map5, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%inserted_slice_64 : tensor<2x32x512x768xf32>) outs(%28 : tensor<2x32x512xf32>) {
^bb0(%in: f32, %out: f32):
%940 = arith.addf %out, %in : f32
linalg.yield %940 : f32
} -> tensor<2x32x512xf32>
%30 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%29, %cst_57 : tensor<2x32x512xf32>, tensor<2x32x512xf32>) outs(%27 : tensor<2x32x512xf32>) {
^bb0(%in: f32, %in_482: f32, %out: f32):
%940 = arith.divf %in, %in_482 : f32
linalg.yield %940 : f32
} -> tensor<2x32x512xf32>
%extracted_slice = tensor.extract_slice %30[0, 0, 0] [1, 32, 512] [1, 1, 1] : tensor<2x32x512xf32> to tensor<1x32x512xf32>
%collapsed_65 = tensor.collapse_shape %extracted_slice [[0, 1], [2]] : tensor<1x32x512xf32> into tensor<32x512xf32>
%extracted_slice_66 = tensor.extract_slice %30[1, 0, 0] [1, 32, 512] [1, 1, 1] : tensor<2x32x512xf32> to tensor<1x32x512xf32>
%collapsed_67 = tensor.collapse_shape %extracted_slice_66 [[0, 1], [2]] : tensor<1x32x512xf32> into tensor<32x512xf32>
%31 = tensor.empty() : tensor<32x512xf32>
%32 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%collapsed_65 : tensor<32x512xf32>) outs(%31 : tensor<32x512xf32>) {
^bb0(%in: f32, %out: f32):
%940 = arith.mulf %in, %in : f32
linalg.yield %940 : f32
} -> tensor<32x512xf32>
%33 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%collapsed_67, %32 : tensor<32x512xf32>, tensor<32x512xf32>) outs(%31 : tensor<32x512xf32>) {
^bb0(%in: f32, %in_482: f32, %out: f32):
%940 = arith.subf %in, %in_482 : f32
linalg.yield %940 : f32
} -> tensor<32x512xf32>
%34 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%33, %cst_56 : tensor<32x512xf32>, tensor<32x512xf32>) outs(%31 : tensor<32x512xf32>) {
^bb0(%in: f32, %in_482: f32, %out: f32):
%940 = arith.maxf %in, %in_482 : f32
linalg.yield %940 : f32
} -> tensor<32x512xf32>
Batch 64 (not working):
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map7 = affine_map<(d0, d1, d2) -> (d1)>
#map8 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map9 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
%cst_56 = arith.constant dense<0.000000e+00> : tensor<64x512xf32>
%cst_57 = arith.constant dense<7.680000e+02> : tensor<64x512xf32>
%cst_62 = arith.constant 0.000000e+00 : f32
%23 = tensor.empty() : tensor<64x512xf32>
%24 = linalg.fill ins(%cst_62 : f32) outs(%23 : tensor<64x512xf32>) -> tensor<64x512xf32>
%25 = linalg.generic {indexing_maps = [#map1, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%22 : tensor<64x512x768xf32>) outs(%24 : tensor<64x512xf32>) {
^bb0(%in: f32, %out: f32):
%937 = arith.addf %out, %in : f32
linalg.yield %937 : f32
} -> tensor<64x512xf32>
%26 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%25, %cst_57 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%23 : tensor<64x512xf32>) {
^bb0(%in: f32, %in_333: f32, %out: f32):
%937 = arith.divf %in, %in_333 : f32
linalg.yield %937 : f32
} -> tensor<64x512xf32>
%27 = linalg.generic {indexing_maps = [#map1, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%21 : tensor<64x512x768xf32>) outs(%24 : tensor<64x512xf32>) {
^bb0(%in: f32, %out: f32):
%937 = arith.addf %out, %in : f32
linalg.yield %937 : f32
} -> tensor<64x512xf32>
%28 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%27, %cst_57 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%23 : tensor<64x512xf32>) {
^bb0(%in: f32, %in_333: f32, %out: f32):
%937 = arith.divf %in, %in_333 : f32
linalg.yield %937 : f32
} -> tensor<64x512xf32>
%29 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%28 : tensor<64x512xf32>) outs(%23 : tensor<64x512xf32>) {
^bb0(%in: f32, %out: f32):
%937 = arith.mulf %in, %in : f32
linalg.yield %937 : f32
} -> tensor<64x512xf32>
%30 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%26, %29 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%23 : tensor<64x512xf32>) {
^bb0(%in: f32, %in_333: f32, %out: f32):
%937 = arith.subf %in, %in_333 : f32
linalg.yield %937 : f32
} -> tensor<64x512xf32>
%31 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%30, %cst_56 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%23 : tensor<64x512xf32>) {
^bb0(%in: f32, %in_333: f32, %out: f32):
%937 = arith.maxf %in, %in_333 : f32
linalg.yield %937 : f32
} -> tensor<64x512xf32>
%expanded_64 = tensor.expand_shape %31 [[0], [1, 2]] : tensor<64x512xf32> into tensor<64x512x1xf32>
%32 = linalg.generic {indexing_maps = [#map3, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%28 : tensor<64x512xf32>) outs(%19 : tensor<64x512x768xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x512x768xf32>
These differences also appear in the IR before running the -iree-util-hoist-into-globals pass.
Ok, that was going to be my question about the input: basically, did hoisting cause these differences or react to them? If you had a gist of the input to the batch-64 pass (perhaps with large elements elided), I wouldn't mind glancing at it.
It gives us something to pull on: for whatever reason, the batch-64 case is a different program, and that at least makes it in the realm of possibility that the hoisting pass is doing something there based on those differences. It's not much, but it semi-explains why behavior differs across this specific batch-size break point.
The thing that the constant hoisting pass does is grab DAGs of things it believes to be constants and turn them into initializers, replacing them with a util.global_load. We'd be looking for places in the program where this happened and then trying to deduce what went wrong.
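One quick, hypothetical way to locate those places is to search the dumped IR for the globals the pass introduces (the file names below are placeholders for IR captured with --mlir-print-ir-before/--mlir-print-ir-after around iree-util-hoist-into-globals):
# count and list the hoisted global loads the pass inserted
grep -c "util.global_load" after_hoist.mlir
grep -n "util.global" after_hoist.mlir | head
# eyeball the full set of changes the pass made
diff before_hoist.mlir after_hoist.mlir | less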
Here are the IRs before and after the pass for batch 32 and 64. I replaced the large elements with dense<...>.
I'm seeing different errors for batch_size=64 now. It looks like the boundary values of IREE outputs are zeros.
all_close: False. shape: (64, 512, 50257)
a[0, 0, 0]: -33.17163848876953, b[0, 0, 0]: -33.17176055908203
a[0, 1, 0]: -84.1167984008789, b[0, 1, 0]: -84.11685180664062
a[0, 0, 1]: -32.558433532714844, b[0, 0, 1]: -32.55856704711914
// a lot of 0.0 vs. X
...
[41, 400, 45958]: 0.0 != -104.0179672241211
[41, 400, 45959]: 0.0 != -102.06230163574219
[41, 400, 45960]: 0.0 != -104.23350524902344
[41, 400, 45961]: 0.0 != -99.41698455810547
[41, 400, 45962]: 0.0 != -106.20645141601562
[41, 400, 45963]: 0.0 != -104.71014404296875
[41, 400, 45964]: 0.0 != -94.92705535888672
[41, 400, 45965]: 0.0 != -101.50908660888672
[41, 400, 45966]: 0.0 != -96.18388366699219
[41, 400, 45967]: 0.0 != -101.66986083984375
[41, 400, 45968]: 0.0 != -101.44477844238281
[41, 400, 45969]: 0.0 != -93.89815521240234
[41, 400, 45970]: 0.0 != -102.68502044677734
[41, 400, 45971]: 0.0 != -102.12843322753906
[41, 400, 45972]: 0.0 != -103.64442443847656
...
The files I'm using are from gs://iree-model-artifacts/jax/jax_models_0.4.23_1706594181/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/.
To reproduce the issue:
Download artifacts:
gcloud storage cp -r gs://iree-model-artifacts/jax/jax_models_0.4.23_1706594181/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/ ~/
cd ~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64
tar -axvf inputs_npy.tgz
tar -axvf outputs_npy.tgz
Compile and run the model:
iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/stablehlo.mlirbc -o /tmp/a.vmfb --iree-llvmcpu-enable-ukernels=all
export MODEL_DIR=/home/hanchung/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64
build/tools/iree-run-module --function=main --input=@${MODEL_DIR}/input_0.npy --input=@${MODEL_DIR}/input_1.npy --module=/tmp/a.vmfb --output=@${MODEL_DIR}/out.npy
Check the output; note that it takes some time to dump the actual_value and expect_value. The compare_npy.py file can be found in the gist.
cd ~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64
python compare_npy.py -a out.npy -b output_0.npy
Thank you!! I was able to reproduce locally following these steps. Thanks a lot for the great repro + tool. Taking over the investigation from here.
Good news / bad news time.
Good news: I do have a regression window. It's still wide and I'll bisect through it tomorrow, but it is a reproducible regression window. Last known good = faff7e0cc5ff9e25303bb7e9f6ac87ee88c5bef4, first known bad = cb020fee27c9b7b23fad33b19e4c92444a0ae797 = today's main branch.
Bad news: this regression window is inconsistent with the date this Issue was filed on August 8. My last known-good faff7e0cc5ff9e25303bb7e9f6ac87ee88c5bef4 is from August 11, and I tried a few older commits from late July to early August; none reproduces.
So there's a bit of a mystery about how this reproduced /then/, but at least for the purpose of debugging the bug we have /now/, I have a regression window to bisect.
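For reference, the bisection itself is just the standard git workflow over that window (the test at each step is whatever repro you use, e.g. the iree-compile / iree-run-module / compare_npy.py steps above):
git bisect start
git bisect bad  cb020fee27c9b7b23fad33b19e4c92444a0ae797   # first known bad
git bisect good faff7e0cc5ff9e25303bb7e9f6ac87ee88c5bef4   # last known good
# then at each step: rebuild, run the repro, and mark the commit
git bisect good   # or: git bisect bad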
Bisection points to https://github.com/openxla/iree/pull/14016 @benvanik
(HEAD) ~/iree git bisect good
63381a8309eb6b9f4e0cac25c98bb7e63647ac2c is the first bad commit
commit 63381a8309eb6b9f4e0cac25c98bb7e63647ac2c
Author: Ben Vanik <ben.vanik@gmail.com>
Date: Thu Oct 19 11:09:49 2023 -0700
Switching external resources to be device-local only. (#14016)
This PR's description says "A temporary flag has been added to revert to the old mappable behavior with --iree-stream-external-resources-mappable=true", but I tried both values of --iree-stream-external-resources-mappable and it fails either way.
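Concretely, that amounts to recompiling and re-running with each value of the flag; a hypothetical pair of runs (paths as in Hanhan's repro steps above):
for v in true false; do
  iree-compile --iree-hal-target-backends=llvm-cpu \
    --iree-stream-external-resources-mappable=$v \
    ${MODEL_DIR}/stablehlo.mlirbc -o /tmp/a_$v.vmfb
  iree-run-module --module=/tmp/a_$v.vmfb --function=main \
    --input=@${MODEL_DIR}/input_0.npy --input=@${MODEL_DIR}/input_1.npy \
    --output=@${MODEL_DIR}/out_$v.npy
done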
To confirm, I tried the previous commit and it works:
commit 87c968c70d0b587996c155053b9e84bf76f335af (HEAD)
Author: Scott Todd <scotttodd@google.com>
Date: Thu Oct 19 10:37:46 2023 -0700
Lint fix trailing space in debugging/releases.md. (#15239)
vmvx in the title is weird - this seems to be all CPU backends? (want to make sure I understand what's wrong)
> vmvx in the title is weird - this seems to be all CPU backends? (want to make sure I understand what's wrong)

Yes, this is CPU backends. Let me update the title.
Ooooh the blamed PR #14016 is an anagram of this Issue #14601 !!!!! It was right under our nose all along!
Here are my steps-to-reproduce at earlier commits such as at #14016:
1. With iree-opt, process the file stablehlo.mlirbc from Hanhan's steps to obtain stablehlo.mlir. This allows reproducing across revisions where the bytecode format changed. (A hypothetical round-trip command is sketched after the commands below.)
2. Compile and run stablehlo.mlir. Here is my own adaptation of these steps, also minimizing compiler flags (some were unnecessary):
ninja
tools/iree-compile \
--iree-hal-target-backends=llvm-cpu \
${MODEL_DIR}/stablehlo.mlir \
-o /tmp/a.vmfb
tools/iree-run-module \
--module=/tmp/a.vmfb \
--function=main \
--input=@${MODEL_DIR}/input_0.npy \
--input=@${MODEL_DIR}/input_1.npy \
--output=@${MODEL_DIR}/out.npy
python ~/testing/compare_npy.py -a $MODEL_DIR/out.npy -b $MODEL_DIR/output_0.npy
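For step 1 above, the bytecode-to-text conversion is just a plain parse/print round trip through iree-opt; a hypothetical invocation (no passes, file names as in the artifact directory):
iree-opt ${MODEL_DIR}/stablehlo.mlirbc -o ${MODEL_DIR}/stablehlo.mlir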
thanks for the updated steps! mind sharing your .mlir? I'll try to take a look at this today
(HEAD) ~/iree ls -l ~/testing/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/stablehlo.mlir
-rw-rw-r-- 1 benoit benoit 995715861 Feb 5 14:47 /home/benoit/testing/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/stablehlo.mlir
1 GB ... big enough that I think it's more reasonable to download the gcloud archive linked by Hanhan above, which gives you the .mlirbc, and just run it through current iree-opt to obtain the .mlir, WDYT?
👍
According to @benvanik, that flag merely bypasses an existing bug in buffer transfers, which IIUC means that #14016 (of which the flag is a partial revert) merely exposed that preexisting bug.
is the model download still valid? I get a 13mb file that has a truncated tar in it, not whatever 1gb thing benoit mentions?
I see some async bugs in the tooling transfer code that may cause issues in some cases, but not with how the tools call it - I can dramatically simplify the helper because it's only used in this one place today. I suspect this will only repro with local-task and that local-sync won't have an issue; that would confirm it (though there still may be other issues). @bjacob / @hanhanW can you make sure the model download works?
@benvanik:
> is the model download still valid? I get a 13mb file that has a truncated tar in it, not whatever 1gb thing benoit mentions?

As mentioned above, the original link in the Issue description is broken in the way you describe. Use instead the gcloud command Hanhan provided in https://github.com/openxla/iree/issues/14601#issuecomment-1927525221
local-sync works, so it's definitely an async transfer issue. Will fix.
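For anyone following along, a hypothetical way to compare the two local CPU drivers on the same compiled module (this assumes the standard local-task and local-sync HAL devices in iree-run-module, with paths from the repro steps above):
# multi-threaded async driver (reproduces the zeroed boundary values)
iree-run-module --device=local-task --module=/tmp/a.vmfb --function=main \
  --input=@${MODEL_DIR}/input_0.npy --input=@${MODEL_DIR}/input_1.npy \
  --output=@${MODEL_DIR}/out_task.npy
# synchronous driver (produces the expected output)
iree-run-module --device=local-sync --module=/tmp/a.vmfb --function=main \
  --input=@${MODEL_DIR}/input_0.npy --input=@${MODEL_DIR}/input_1.npy \
  --output=@${MODEL_DIR}/out_sync.npy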
Heh heh, found it just by looking at the npy file in a hex editor - the first zero value is at [41, 373, 1019], which is at byte offset 4294967296 = (41*(512*50257) + 373*50257 + 1019)*4, or 100000000h. The file contents go to zero exactly 80h past that, because the numpy header is 80h bytes.
so, it's possible there's a uint32_t used somewhere for a size in the transfer flow that shouldn't be - will debug through and find it.
(or it's the fwrite 4gb limit again)
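A quick sanity check of that arithmetic (the output tensor is 64x512x50257 f32, so the flattened index of element [41, 373, 1019] times 4 bytes lands exactly on the 4 GiB / uint32_t boundary):
# byte offset of element [41, 373, 1019] in a row-major 64x512x50257 f32 tensor
echo $(( (41 * 512 * 50257 + 373 * 50257 + 1019) * 4 ))              # 4294967296
printf '0x%x\n' $(( (41 * 512 * 50257 + 373 * 50257 + 1019) * 4 ))   # 0x100000000 == 2^32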
Fix is in #16364.
RE what this test model is doing: returning 6GB of outputs is crazy, and I'd not use this GPT2 as a benchmark for anything in its current state. A good chunk of time is going to be churning through that memory.
Thanks for the fix anyway! And thanks for the analysis, that's interesting. I wonder why this "GPT2" model is returning 6 GB outputs. Could that be that it's returning internal state for the next token iteration? Is 6 GB a reasonable order of magnitude for even that? Or is it truncated in some other, more arbitrary way? @mariecwhite
yeah I have no idea what the 6GB in a single tensor is, but the whole model with all the constants is only 700MB so blowing that up to 6GB feels kinda crazy. Can't imagine that's working as intended :)
@MaheshRavishankar that probably means we're back to not being able to use that model in next week's talk :-)
You can use it, just with the caveat that between the extra 6GB transient/result, 400k workgroup mmt4d ops, and 50% utilization unpacks it's got a lot of headroom for improvement :)
Well, the fact that we still have a lot of room for improvement is good. The comparison is what it is... I am not sure why we can't use it as-is now.
The problem is, now that we know how much room for improvement there is, that means that where we were was not good, and the baseline we were comparing against was not good either. We should focus on numbers that are absolutely good, not just good relatively to a poor baseline, and choosing poor baselines to compare against is a smell in itself for any well-informed audience.
I have no idea how good the other baselines are...
> Thanks for the fix anyway! And thanks for the analysis, that's interesting. I wonder why this "GPT2" model is returning 6 GB outputs. Could that be that it's returning internal state for the next token iteration? Is 6 GB a reasonable order of magnitude for even that? Or is it truncated in some other, more arbitrary way? @mariecwhite

This is expected. It's returning the logits (before softmax and post-processing like argmax, top-k, etc.). The size is batch_size x input_sequence_length x vocab_size: 64 x 512 x ~50k ≈ 1.6G elements x 4 bytes (float) ≈ 6.6 GB.
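The arithmetic, using the actual vocab size of 50257 from the 64x512x50257 output shape above:
# batch_size * sequence_length * vocab_size elements, 4 bytes per f32
echo $(( 64 * 512 * 50257 * 4 ))   # 6588203008 bytes, ~6.6 GB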
What happened?
For both CPU default flags and data-tiling+ukernel paths, we see an accuracy error on GPT2 for batch sizes 48 and 64 (possibly higher but haven't tried). Accuracy is correct for batch sizes 1, 2, 8, 16 and 32.
Also seeing the same behavior with the VMVX backend.
Steps to reproduce your issue
Download and extract https://storage.googleapis.com/iree-shared-files/jax_models.tar.xz.
Should see output like below, indicating the accuracy is correct:
Output should look like below, indicating accuracy error.
The same behavior is seen when using data tiling + ukernel compiler flags.
What component(s) does this issue relate to?
Compiler
Version information
iree-compiler release 20230804.603
Additional context
No response