openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Pallas/Triton segfault on H100 #17356

Open jaro-sevcik opened 4 days ago

jaro-sevcik commented 4 days ago

After the commit https://github.com/openxla/xla/commit/cb304cf640723620d9a67173e6bb650e34217169, JAX crashes in Triton on H100 with the following repro:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def mha_forward_kernel(
    q_ref,  # (64, 128) bf16
    k_ref,  # (128, 64) bf16
    v_ref,  # (64, 128) bf16
    o_ref,  # Output, (64, 128) bf16
):
    q = pl.load(q_ref, (pl.dslice(0, 64), pl.dslice(None)))    # (64, 128)
    k = pl.load(k_ref, (pl.dslice(None), pl.dslice(0, 64)))    # (128, 64)
    qk = pl.dot(q, k).astype(jnp.bfloat16)                     # (64, 64)
    v = pl.load(v_ref, (pl.dslice(0, 64), pl.dslice(0, 128)))  # (64, 128)
    o = pl.dot(qk, v).astype(jnp.bfloat16)                     # (64, 128)
    pl.store(o_ref, (pl.dslice(0, 64), pl.dslice(None)), o)

q = jnp.zeros((64, 128), dtype=jnp.bfloat16)
k = jnp.zeros((128, 64), dtype=jnp.bfloat16)
v = jnp.zeros((64, 128), dtype=jnp.bfloat16)
pl.pallas_call(
    mha_forward_kernel,
    grid=(1,),
    in_specs=[
        pl.BlockSpec(lambda _: (0, 0), (64, 128)),
        pl.BlockSpec(lambda _: (0, 0), (128, 64)),
        pl.BlockSpec(lambda _: (0, 0), (64, 128)),
    ],
    out_specs=pl.BlockSpec(lambda _: (0, 0), (64, 128)),
    compiler_params=dict(triton=dict(num_warps=8)),
    out_shape=jax.ShapeDtypeStruct(shape=(64, 128), dtype=q.dtype),
    name="mha_forward",
)(q, k, v)
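To check the kernel's output once it compiles again, the same computation can be written in plain `jax.numpy` and used as a reference. This is a sketch, not part of the original report: `reference_mha_forward` and the random inputs are illustrative, and it assumes `pl.dot` on bf16 operands accumulates in float32 before the `.astype(jnp.bfloat16)` cast.

```python
import jax
import jax.numpy as jnp


def reference_mha_forward(q, k, v):
    """Plain jax.numpy version of what mha_forward_kernel computes (no Pallas/Triton)."""
    # (64, 128) @ (128, 64) -> (64, 64), accumulated in float32 and cast to bf16,
    # mirroring pl.dot(q, k).astype(jnp.bfloat16) in the kernel (assumed f32 accumulation).
    qk = jnp.dot(q, k, preferred_element_type=jnp.float32).astype(jnp.bfloat16)
    # (64, 64) @ (64, 128) -> (64, 128), again cast back to bf16.
    return jnp.dot(qk, v, preferred_element_type=jnp.float32).astype(jnp.bfloat16)


# Random bf16 inputs with the repro's shapes (the repro itself uses zeros).
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (64, 128), dtype=jnp.bfloat16)
k = jax.random.normal(key_k, (128, 64), dtype=jnp.bfloat16)
v = jax.random.normal(key_v, (64, 128), dtype=jnp.bfloat16)

expected = jax.jit(reference_mha_forward)(q, k, v)
```

On a GPU host, the Pallas result can then be compared against `expected` with `jnp.allclose` after casting both sides to float32, using loose tolerances since both paths round intermediates to bf16.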

The stack trace:

#0  mlir::detail::OperandStorage::OperandStorage
#1  mlir::Operation::create
#2  mlir::Operation::create
#3  mlir::Operation::create
#4  mlir::OpBuilder::create
#5  mlir::LLVM::InsertElementOp mlir::OpBuilder::create<mlir::LLVM::InsertElementOp, mlir::Type&, mlir::Value&, mlir::Value&, mlir::Value>
#6  loadReg
#7  convertDot
#8  convertWGMMA
#9  (anonymous namespace)::WarpGroupDotOpConversion::matchAndRewrite
#10 mlir::ConvertOpToLLVMPattern<mlir::triton::nvidia_gpu::WarpGroupDotOp>::matchAndRewrite
#11 mlir::ConversionPattern::matchAndRewrite
#12 void llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite
#13 mlir::PatternApplicator::matchAndRewrite
#14 (anonymous namespace)::OperationLegalizer::legalize
#15 mlir::OperationConverter::convert
#16 mlir::OperationConverter::convertOperations
#17 mlir::applyPartialConversion
#18 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation
#19 mlir::detail::OpToOpPassAdaptor::run
#20 mlir::detail::OpToOpPassAdaptor::runPipeline
#21 mlir::PassManager::run
#22 xla::gpu::CompileTritonToLLVM
#23 std::_Function_handler<absl::lts_20230802::StatusOr<xla::gpu::KernelReuseCache::Entry> (), xla::gpu::IrEmitterUnnested::EmitTritonCustomCall
#24 xla::gpu::KernelReuseCache::GetWithStatus
#25 xla::gpu::IrEmitterUnnested::EmitTritonCustomCall
#26 xla::gpu::IrEmitterUnnested::EmitHloInstruction
...

Here is the JAX version and environment info:

>>> import jax; jax.print_environment_info()
jax:    0.4.34.dev20240918+988ed2bd7
jaxlib: 0.4.34.dev20240919
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.15.0-84-generic', version='#93-Ubuntu SMP Tue Sep 5 17:16:10 UTC 2023', machine='x86_64')

$ nvidia-smi
Thu Sep 19 08:36:26 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:45:00.0 Off |                    0 |
| N/A   25C    P0             65W /  700W |     534MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
jaro-sevcik commented 4 days ago

For completeness, here is the crashing HLO:

HloModule jit_wrapped, entry_computation_layout={(bf16[64,128]{1,0}, bf16[128,64]{1,0}, bf16[64,128]{1,0})->bf16[64,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.5 {
  Arg_0.1 = bf16[64,128]{1,0} parameter(0), metadata={op_name="args[0]"}
  Arg_1.2 = bf16[128,64]{1,0} parameter(1), metadata={op_name="args[1]"}
  Arg_2.3 = bf16[64,128]{1,0} parameter(2), metadata={op_name="args[2]"}
  ROOT custom-call.4 = bf16[64,128]{1,0} custom-call(Arg_0.1, Arg_1.2, Arg_2.3), custom_call_target="__gpu$xla.gpu.triton", operand_layout_constraints={bf16[64,128]{1,0}, bf16[128,64]{1,0}, bf16[64,128]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(wrapped)/jit(main)/pallas_call" source_file="//workspace/triton_crash_repro_wip.py" source_line=21}, backend_config={debug = false, grid_x = 1 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIR20.0.0git\00\017\07\01\05\09!\01\03\0F\03\17\13\17\1B\1F#'+/37;\05\09?CGK\03\A5g+\01e\07\0F\0F\0F\0F\0F\0F\13\13\0B\0B\0F\13\13\0F\0B\0F\0F\0B\0F\0B\0F\0B\0F\1B\0B\0F\0B\0B\0F\0F\13\0B\13\0F\0F\13\0F\13\0F\0F\0F\13\0F\13\0F\0B\0F\0F\13\05\03Y\01)\0F\1F\1F\1F\07\17\17\1F\1B\1B\07\1F\1F\1F\1F\0B\1B\1B\1F\1F\03\039\02\AA\04\1F\11\01\01\1D]_\1D\1F;\1D\1FE\1D\1FQ\11\01\05\11\01\02\04\11\01\02\02\05'\05)\1DAC#\01\01\01\03\0335\11\1F\00\05+\1D)+\1D)/\05-\13\09\01\05/\15K\17\051\15W\17\01\09\1B\1B\1B\1B\053\11\01\81\0D\1D\055\15=\17\1D\15?\17\13\17\01\057\17\13+\01\15G\17\1D\15I\17\13\19\01\1D\15M\17\13\1B\01\1D-+\15S\17\1D\15U\17\13\1D\01\1D\15Y\17\13\1F\01\1D-/\059\15a\17\1D\15c\17\13!\01#arith.overflow<none>\00\01\02\02\1B\05\02\02\02\04\01\1B\05\02\04\02\02\01\1B\05\02\02\02\04)\0B\1B\03\02\02\01\1B\03\02\04\01\1B\05\02\02\02\04\15\1B\05\02\02\05\01\1B\05\05\02\04\01\07\1B\05\02\04\02\02)\1B\05\02\02\02\02\09\1B\05\02\02\02\04\09\05\09))))\01\01\09\1B\05\02\04\05\01\1B\05\05\02\02\01\1B\05\02\04\02\02\15\1B\05\02\02\02\02\15!tt.ptr<bf16>\00\04\16\0F\05\01P\01\01\07\04\F2\0E\03\01\05\11P\01\03\07\04\C6\0E\03\06\02\FE\03\09QQQQ\00\13B\01\05\03\01\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05\0B\0F\19B\01\0B\03\01\1BF\01\09\03\01\05\0D\13\19B\01\05\03\01\19B\01\05\03\01\19B\01\0B\03\01\1BF\01\09\03\01\05\17\1B\19B\01\07\03\01\1BF\01\09\03\01\05\19\1F\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05#'\19B\01\0B\03\01\1BF\01\09\03\01\05%+\19B\01\05\03\01\19B\01\05\03\01\19
B\01\07\03\01\1BF\01\09\03\01\05/3\19B\01\0B\03\01\1BF\01\09\03\01\0517\19B\07\05\03\01\03\06\07\03\03\03;\05B\07\0D\03\0B\07F\07\0F\03\11\03?\09\06\07\03\03\03A\03\06\07\03\03\03\11\1DF\07\09\03\03\05CE\19B\07\0B\03\01\03\06\07\03\03\03I\1BF\07\09\03\03\05GK\1DF\07\09\03\03\05=M\05B\07\11\03\0D\07F\07\05\03\13\03Q\09\06\07\03\03\03S\03\06\07\03\03\03\15\1DF\07\09\03\03\05UW\19B\07\0F\03\01\03\06\07\03\03\03[\1BF\07\09\03\03\05Y]\1DF\07\09\03\03\05O_\03\06\07\03\07\03\01\0B\06\07\03\07\05ca\0DF\07\13\03\0F\03e\19B\09\05\03\01\03\06\09\03\05\03i\05B\09\11\03\0D\07F\09\0F\03!\03m\09\06\09\03\05\03o\03\06\09\03\05\03\1D\1DF\09\09\03\05\05qs\19B\09\07\03\01\03\06\09\03\05\03w\1BF\09\09\03\05\05uy\1DF\09\09\03\05\05k{\05B\09\0D\03\0B\07F\09\05\03#\03\7F\09\06\09\03\05\03\81\03\06\09\03\05\03!\1DF\09\09\03\05\05\83\85\19B\09\0F\03\01\03\06\09\03\05\03\89\1BF\09\09\03\05\05\87\8B\1DF\09\09\03\05\05}\8D\03\06\09\03\17\03\03\0B\06\09\03\17\05\91\8F\0DF\09\13\03%\03\93\19B!\15\03\09\03\06!\03\19\03\97\0FF!\17\03\19\07g\95\99\1FFO\19\03'\03\9B\19B\0B\05\03\01\03\06\0B\03\03\03\9F\05B\0B\0D\03\0B\07F\0B\0F\03\11\03\A3\09\06\0B\03\03\03\A5\03\06\0B\03\03\03)\1DF\0B\09\03\03\05\A7\A9\19B\0B\0B\03\01\03\06\0B\03\03\03\AD\1BF\0B\09\03\03\05\AB\AF\1DF\0B\09\03\03\05\A1\B1\05B\0B\11\03\0D\07F\0B\05\03\13\03\B5\09\06\0B\03\03\03\B7\03\06\0B\03\03\03-\1DF\0B\09\03\03\05\B9\BB\19B\0B\0F\03\01\03\06\0B\03\03\03\BF\1BF\0B\09\03\03\05\BD\C1\1DF\0B\09\03\03\05\B3\C3\03\06\0B\03\07\03\05\0B\06\0B\03\07\05\C7\C5\0DF\0B\13\03\0F\03\C9\19B#\15\03\09\03\06#\03\1B\03\CD\0FF#\17\03\1B\07\9D\CB\CF\1FF[\19\03\0F\03\D1\19B\05\05\03\01\03\06\05\03\03\03\D5\05B\05\0D\03\0B\07F\05\0F\03\11\03\D9\09\06\05\03\03\03\DB\03\06\05\03\03\035\1DF\05\09\03\03\05\DD\DF\19B\05\0B\03\01\03\06\05\03\03\03\E3\1BF\05\09\03\03\05\E1\E5\1DF\05\09\03\03\05\D7\E7\05B\05\11\03\0D\07F\05\05\03\13\03\EB\09\06\05\03\03\03\ED\03\06\05\03\03\039\1DF\05\09\03\03\05\EF\F1\19B\05\0F\03\01\03\06\05\03\03\03\F5\1BF\05\09\03\03\05\F3
\F7\1DF\05\09\03\03\05\E9\F9\03\06\05\03\07\03\07\0B\06\05\03\07\05\FD\FB\0DF\05\13\03\0F\03\FF\15D\05\1B\05\FF\D3\17\00\01\06\03\01\05\01\00*\05;\1B\13\0F!-\1B\19\1B'M\0F\0B\0B\13\0F\0D\1F\0B\09\0B\0F\15\19\17\0D\0F\0D\07\11builtin\00tt\00arith\00module\00splat\00make_range\00expand_dims\00broadcast\00addptr\00load\00dot\00func\00get_program_id\00store\00return\00constant\00muli\00addi\00truncf\00//workspace/triton_crash_repro_wip.py\00mha_forward_kernel\00/masked_load\00mha_forward\00/dot_general\00/convert_element_type\00tt.divisibility\00public\00<module>\00/masked_swap\00\08_\1D\05K\01\0Bc7\01%s\03\03\03\11\03\CB\03\0F\05\11\03\03\0D\05\0F\03\113\1B\1B;\01\07\01\03\03'\05\07\07\05\01\01\073\1B\1B", name = "mha_forward", num_stages = 3 : i32, num_warps = 8 : i32}
}
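A module like the one above can also be captured from Python rather than from an XLA dump (for example one written via `XLA_FLAGS=--xla_dump_to=<dir>`). The sketch below lowers a jitted function and prints it both as StableHLO and as pre-optimization HLO using `Lowered.as_text`. To keep it runnable on any backend it lowers a plain `jax.numpy` stand-in (`two_matmuls`, a hypothetical name) instead of the `pallas_call`; applying the same calls to the jitted `pallas_call` on the H100 host should yield a module containing the `__gpu$xla.gpu.triton` custom call.

```python
import jax
import jax.numpy as jnp


def two_matmuls(q, k, v):
    # Hypothetical stand-in for the kernel's math, so the example lowers on any backend.
    qk = jnp.dot(q, k, preferred_element_type=jnp.float32).astype(jnp.bfloat16)
    return jnp.dot(qk, v, preferred_element_type=jnp.float32).astype(jnp.bfloat16)


# Abstract inputs with the repro's shapes and dtypes; no device buffers are allocated.
arg_specs = (
    jax.ShapeDtypeStruct((64, 128), jnp.bfloat16),
    jax.ShapeDtypeStruct((128, 64), jnp.bfloat16),
    jax.ShapeDtypeStruct((64, 128), jnp.bfloat16),
)

lowered = jax.jit(two_matmuls).lower(*arg_specs)
stablehlo_text = lowered.as_text()         # StableHLO (MLIR) text, the default dialect
hlo_text = lowered.as_text(dialect="hlo")  # pre-optimization XLA HLO text
print(hlo_text.splitlines()[0])            # module header line
```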
cheshire commented 3 days ago

Seems like a Triton crash, but could you provide the segfault stack trace from an ASAN build?

jaro-sevcik commented 5 hours ago

I have not run with ASAN, but the debug build fails on an out-of-bounds access:

... libc ...
#4  __assert_fail () from /lib/x86_64-linux-gnu/libc.so.6
#5  llvm::SmallVectorTemplateCommon<mlir::Value, void>::operator[] (this=0x7ffffffe2230, idx=16) at external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:295
#6  loadReg (rewriter=..., loc=..., elements=..., startIndex=16, numElements=8, insertBefore=0x555557ae3910) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:292
#7  convertDot (typeConverter=0x7ffffffe4a38, rewriter=..., loc=..., op=0x5555579280d0, a=..., b=..., c=..., d=..., useCOperand=..., loadedA=..., loadedB=..., loadedC=..., allowTF32=true,
    needsPartialAccumulator=false, maxNumImpreciseAcc=0, sync=true, thread=...) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:447
#8  convertWGMMA (op=..., adaptor=..., typeConverter=0x7ffffffe4a38, rewriter=..., thread=...) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:511
#9  (anonymous namespace)::WarpGroupDotOpConversion::matchAndRewrite (this=0x5555579d6600, op=..., adaptor=..., rewriter=...)
    at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp:95
#10 mlir::ConvertOpToLLVMPattern<mlir::triton::nvidia_gpu::WarpGroupDotOp>::matchAndRewrite (this=0x5555579d6600, op=0x5555579280d0, operands=..., rewriter=...)
    at external/llvm-project/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h:165
...
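The debug frames narrow the failure down a little further. `SmallVector::operator[]` asserts `idx < size()`, and frame #6 shows `loadReg` called with `startIndex=16, numElements=8` failing at `idx=16`. Assuming `loadReg` reads `elements[startIndex]` through `elements[startIndex + numElements - 1]` in order (an assumption, not checked against the Triton source), a failure on the very first read means the `elements` vector held at most 16 values, while the call needed at least 24. The hypothetical helper `first_failing_index` below just makes that arithmetic explicit.

```python
def first_failing_index(size, start_index, num_elements):
    """Return the first index a loadReg-style loop would read past the end of a
    vector of `size` elements, or None if every read is in bounds.

    Assumption: the loop reads indices start_index .. start_index + num_elements - 1
    in order, and each read is checked like SmallVector::operator[] (idx < size).
    """
    for idx in range(start_index, start_index + num_elements):
        if idx >= size:
            return idx
    return None


# Values from frame #6 of the debug trace.
START_INDEX, NUM_ELEMENTS = 16, 8

# Vector sizes for which the first failing read is idx=16, as observed.
consistent_sizes = [
    size for size in range(64)
    if first_failing_index(size, START_INDEX, NUM_ELEMENTS) == 16
]
print(max(consistent_sizes))  # 16
# Minimum size for the call to succeed under the same assumption.
print(START_INDEX + NUM_ELEMENTS)  # 24
```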