Open jaro-sevcik opened 4 days ago
For completeness, here is the crashing HLO:
HloModule jit_wrapped, entry_computation_layout={(bf16[64,128]{1,0}, bf16[128,64]{1,0}, bf16[64,128]{1,0})->bf16[64,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}
ENTRY main.5 {
Arg_0.1 = bf16[64,128]{1,0} parameter(0), metadata={op_name="args[0]"}
Arg_1.2 = bf16[128,64]{1,0} parameter(1), metadata={op_name="args[1]"}
Arg_2.3 = bf16[64,128]{1,0} parameter(2), metadata={op_name="args[2]"}
ROOT custom-call.4 = bf16[64,128]{1,0} custom-call(Arg_0.1, Arg_1.2, Arg_2.3), custom_call_target="__gpu$xla.gpu.triton", operand_layout_constraints={bf16[64,128]{1,0}, bf16[128,64]{1,0}, bf16[64,128]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(wrapped)/jit(main)/pallas_call" source_file="//workspace/triton_crash_repro_wip.py" source_line=21}, backend_config={debug = false, grid_x = 1 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIR20.0.0git\00\017\07\01\05\09!\01\03\0F\03\17\13\17\1B\1F#'+/37;\05\09?CGK\03\A5g+\01e\07\0F\0F\0F\0F\0F\0F\13\13\0B\0B\0F\13\13\0F\0B\0F\0F\0B\0F\0B\0F\0B\0F\1B\0B\0F\0B\0B\0F\0F\13\0B\13\0F\0F\13\0F\13\0F\0F\0F\13\0F\13\0F\0B\0F\0F\13\05\03Y\01)\0F\1F\1F\1F\07\17\17\1F\1B\1B\07\1F\1F\1F\1F\0B\1B\1B\1F\1F\03\039\02\AA\04\1F\11\01\01\1D]_\1D\1F;\1D\1FE\1D\1FQ\11\01\05\11\01\02\04\11\01\02\02\05'\05)\1DAC#\01\01\01\03\0335\11\1F\00\05+\1D)+\1D)/\05-\13\09\01\05/\15K\17\051\15W\17\01\09\1B\1B\1B\1B\053\11\01\81\0D\1D\055\15=\17\1D\15?\17\13\17\01\057\17\13+\01\15G\17\1D\15I\17\13\19\01\1D\15M\17\13\1B\01\1D-+\15S\17\1D\15U\17\13\1D\01\1D\15Y\17\13\1F\01\1D-/\059\15a\17\1D\15c\17\13!\01#arith.overflow<none>\00\01\02\02\1B\05\02\02\02\04\01\1B\05\02\04\02\02\01\1B\05\02\02\02\04)\0B\1B\03\02\02\01\1B\03\02\04\01\1B\05\02\02\02\04\15\1B\05\02\02\05\01\1B\05\05\02\04\01\07\1B\05\02\04\02\02)\1B\05\02\02\02\02\09\1B\05\02\02\02\04\09\05\09))))\01\01\09\1B\05\02\04\05\01\1B\05\05\02\02\01\1B\05\02\04\02\02\15\1B\05\02\02\02\02\15!tt.ptr<bf16>\00\04\16\0F\05\01P\01\01\07\04\F2\0E\03\01\05\11P\01\03\07\04\C6\0E\03\06\02\FE\03\09QQQQ\00\13B\01\05\03\01\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05\0B\0F\19B\01\0B\03\01\1BF\01\09\03\01\05\0D\13\19B\01\05\03\01\19B\01\05\03\01\19B\01\0B\03\01\1BF\01\09\03\01\05\17\1B\19B\01\07\03\01\1BF\01\09\03\01\05\19\1F\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05#'\19B\01\0B\03\01\1BF\01\09\03\01\05%+\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05/3\19B\01\0B\03\01\1BF\01\09\03\01\0517\19B\07\05\03\01\03\06\07\03\03\03;\05B\07\0D\03\0B\07F\07\0F\03\11\03?\09\06\07\03\03\03A\03\06\07\03\03\03\11\1DF\07\09\03\03\05CE\19B\07\0B\03\01\03\06\07\03\03\03I\1BF\07\09\03\03\05GK\1DF\07\09\03\03\05=M\05B\07\11\03\0D\07F\07\05\03\13\03Q\09\06\07\03\03\03S\03\06\07\03\03\03\15\1DF\07\09\03\03\05UW\19B\07\0F\03\01\03\06\07\03\03\03[\1BF\07\09\03\03\05Y]\1DF\07\09\03\03\05O_\03\06\07\03\07\03\01\0B\06\07\03\07\05ca\0DF\07\13\03\0F\03e\19B\09\05\03\01\03\06\09\03\05\03i\05B\09\11\03\0D\07F\09\0F\03!\03m\09\06\09\03\05\03o\03\06\09\03\05\03\1D\1DF\09\09\03\05\05qs\19B\09\07\03\01\03\06\09\03\05\03w\1BF\09\09\03\05\05uy\1DF\09\09\03\05\05k{\05B\09\0D\03\0B\07F\09\05\03#\03\7F\09\06\09\03\05\03\81\03\06\09\03\05\03!\1DF\09\09\03\05\05\83\85\19B\09\0F\03\01\03\06\09\03\05\03\89\1BF\09\09\03\05\05\87\8B\1DF\09\09\03\05\05}\8D\03\06\09\03\17\03\03\0B\06\09\03\17\05\91\8F\0DF\09\13\03%\03\93\19B!\15\03\09\03\06!\03\19\03\97\0FF!\17\03\19\07g\95\99\1FFO\19\03'\03\9B\19B\0B\05\03\01\03\06\0B\03\03\03\9F\05B\0B\0D\03\0B\07F\0B\0F\03\11\03\A3\09\06\0B\03\03\03\A5\03\06\0B\03\03\03)\1DF\0B\09\03\03\05\A7\A9\19B\0B\0B\03\01\03\06\0B\03\03\03\AD\1BF\0B\09\03\03\05\AB\AF\1DF\0B\09\03\03\05\A1\B1\05B\0B\11\03\0D\07F\0B\05\03\13\03\B5\09\06\0B\03\03\03\B7\03\06\0B\03\03\03-\1DF\0B\09\03\03\05\B9\BB\19B\0B\0F\03\01\03\06\0B\03\03\03\BF\1BF\0B\09\03\03\05\BD\C1\1DF\0B\09\03\03\05\B3\C3\03\06\0B\03\07\03\05\0B\06\0B\03\07\05\C7\C5\0DF\0B\13\03\0F\03\C9\19B#\15\03\09\03\06#\03\1B\03\CD\0FF#\17\03\1B\07\9D\CB\CF\1FF[\19\03\0F\03\D1\19B\05\05\03\01\03\06\05\03\03\03\D5\05B\05\0D\03\0B\07F\05\0F\03\11\03\D9\09\06\05\03\03\03\DB\03\06\05\03\03\035\1DF\05\09\03\03\05\DD\DF\19B\05\0B\03\01\03\06\05\03\03\03\E3\1BF\05\09\03\03\05\E1\E5\1DF\05\09\03\03\05\D7\E7\05B\05\11\03\0D\07F\05\05\03\13\03\EB\09\06\05\03\03\03\ED\03\06\05\03\03\039\1DF\05\09\03\03\05\EF\F1\19B\05\0F\03\01\03\06\05\03\03\03\F5\1BF\05\09\03\03\05\F3\F7\1DF\05\09\03\03\05\E9\F9\03\06\05\03\07\03\07\0B\06\05\03\07\05\FD\FB\0DF\05\13\03\0F\03\FF\15D\05\1B\05\FF\D3\17\00\01\06\03\01\05\01\00*\05;\1B\13\0F!-\1B\19\1B'M\0F\0B\0B\13\0F\0D\1F\0B\09\0B\0F\15\19\17\0D\0F\0D\07\11builtin\00tt\00arith\00module\00splat\00make_range\00expand_dims\00broadcast\00addptr\00load\00dot\00func\00get_program_id\00store\00return\00constant\00muli\00addi\00truncf\00//workspace/triton_crash_repro_wip.py\00mha_forward_kernel\00/masked_load\00mha_forward\00/dot_general\00/convert_element_type\00tt.divisibility\00public\00<module>\00/masked_swap\00\08_\1D\05K\01\0Bc7\01%s\03\03\03\11\03\CB\03\0F\05\11\03\03\0D\05\0F\03\113\1B\1B;\01\07\01\03\03'\05\07\07\05\01\01\073\1B\1B", name = "mha_forward", num_stages = 3 : i32, num_warps = 8 : i32}
}
Seems like a Triton crash, but could you provide segfault with asan?
I have not run with ASAN, but debug version fails on OOB access:
... libc ...
#4 __assert_fail () from /lib/x86_64-linux-gnu/libc.so.6
#5 llvm::SmallVectorTemplateCommon<mlir::Value, void>::operator[] (this=0x7ffffffe2230, idx=16) at external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:295
#6 loadReg (rewriter=..., loc=..., elements=..., startIndex=16, numElements=8, insertBefore=0x555557ae3910) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:292
#7 convertDot (typeConverter=0x7ffffffe4a38, rewriter=..., loc=..., op=0x5555579280d0, a=..., b=..., c=..., d=..., useCOperand=..., loadedA=..., loadedB=..., loadedC=..., allowTF32=true,
needsPartialAccumulator=false, maxNumImpreciseAcc=0, sync=true, thread=...) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:447
#8 convertWGMMA (op=..., adaptor=..., typeConverter=0x7ffffffe4a38, rewriter=..., thread=...) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:511
#9 (anonymous namespace)::WarpGroupDotOpConversion::matchAndRewrite (this=0x5555579d6600, op=..., adaptor=..., rewriter=...)
at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp:95
#10 mlir::ConvertOpToLLVMPattern<mlir::triton::nvidia_gpu::WarpGroupDotOp>::matchAndRewrite (this=0x5555579d6600, op=0x5555579280d0, operands=..., rewriter=...)
at external/llvm-project/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h:165
...
After the commit https://github.com/openxla/xla/commit/cb304cf640723620d9a67173e6bb650e34217169, JAX crashes in Triton on H100 with the following repro:
The stack trace:
Here is the JAX version: