intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs

[PyTorch Upstream] The behavior of Intel Triton tl.load with mask does not align with CUDA Triton #1125

Closed: etaf closed this issue 4 months ago

etaf commented 4 months ago

The stock PyTorch Inductor test test_torchinductor.py:test_nll_loss_backward_xpu fails on Intel GPU, and it blocks the CI tests of all XPU-related upstream PRs. The conclusion is that the values returned for masked-out elements of tl.load are undefined in Intel GPU Triton but zero in CUDA Triton. We've also created an issue in stock PyTorch: https://github.com/pytorch/pytorch/issues/126173

Details:

The error is caused by a tl.device_assert in the Triton kernel. Inductor generates the same Triton kernel for both CUDA and Intel GPU, except for device-specific metadata, as follows:

@triton_heuristics.pointwise(
    size_hints=[8],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='xpu', index=0, cc={'driver_version': '1.3.27642', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_fp16': True, 'has_fp64': True, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51539607552, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '1.3'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_nll_loss_backward_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '121590420ce95e3dd78af161b1571fecb2ee847a144d53f7899eb15446278d5b', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): # in_ptr0 is an int64 tensor (0, 0, 0, 0, 0).
    xnumel = 5
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    # for debug
    # tl.device_print("tmp0", tmp0)
    tmp1 = tl.full([1], -100, tl.int64)
    tmp2 = tmp0 != tmp1
    tmp3 = tl.full([1], 0, tl.int64)
    tmp4 = tl.where(tmp2, tmp0, tmp3)
    tmp5 = tmp4 + 5
    tmp6 = tmp4 < 0
    tmp7 = tl.where(tmp6, tmp5, tmp4)
    # for debug
    # tl.device_print("tmp7", tmp7)
    tl.device_assert((0 <= tmp7) & (tmp7 < 5), "[etaf] index out of bounds: 0 <= tmp7 < 5")  # assert fail
    tmp8 = -1.0
''', device_str='xpu')
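For context, here is a hedged sketch of how a kernel compiled with size_hints=[8] ends up covering 8 lanes while only 5 elements are valid. It treats the kernel body above as a plain @triton.jit function named triton_; the real Inductor wrapper uses its own grid helper and the triton_heuristics launcher, so the call style below is an illustrative assumption.

# Hedged sketch only, not the Inductor-generated launcher: it assumes a plain
# @triton.jit kernel with the body shown above, referred to here as `triton_`.
import torch
import triton

xnumel = 5    # number of valid elements
XBLOCK = 8    # block size picked from size_hints=[8]
in_ptr0 = torch.zeros(xnumel, dtype=torch.int64, device="xpu")
out_ptr0 = torch.empty(xnumel, dtype=torch.float32, device="xpu")

grid = (triton.cdiv(xnumel, XBLOCK),)  # -> (1,): a single program of 8 lanes
triton_[grid](in_ptr0, out_ptr0, xnumel, XBLOCK=XBLOCK)
# Lanes with xindex 5..7 have xmask == False; their tl.load results are the
# values that differ between the Intel GPU and CUDA outputs shown below.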

The Inductor-generated kernel was launched with size_hints=[8], but only 5 elements need to be processed. In our case, in_ptr0 is an int64 tensor with five 0 elements, so tmp0 from the masked tl.load comes out as (0, 0, 0, 0, 0, ?, ?, ?) on Intel GPU, as in the following output:

pid (0, 0, 0) idx (0) tmp0: 0
pid (0, 0, 0) idx (1) tmp0: 0
pid (0, 0, 0) idx (2) tmp0: 0
pid (0, 0, 0) idx (3) tmp0: 0
pid (0, 0, 0) idx (4) tmp0: 0
pid (0, 0, 0) idx (5) tmp0: 18374686479671623680
pid (0, 0, 0) idx (6) tmp0: 18374686479688007680
pid (0, 0, 0) idx (7) tmp0: 18374686479688024560

The tl.device_assert then fails in work items idx 5, 6, and 7 because of the undefined values.

The same masked tl.load on CUDA Triton gives the following result:

pid (0, 0, 0) idx (0) tmp0: 0
pid (0, 0, 0) idx (1) tmp0: 0
pid (0, 0, 0) idx (2) tmp0: 0
pid (0, 0, 0) idx (3) tmp0: 0
pid (0, 0, 0) idx (4) tmp0: 0
pid (0, 0, 0) idx (5) tmp0: 0
pid (0, 0, 0) idx (6) tmp0: 0
pid (0, 0, 0) idx (7) tmp0: 0

The result shows that CUDA Triton sets the masked-out elements of tl.load to zero; Intel GPU Triton should be aligned with this behavior.
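To isolate the behavior from Inductor, a minimal standalone sketch along the following lines should reproduce the divergence. The kernel name, launch shape, and the commented other=0 line are illustrative assumptions, not part of the original report or of any fix.

# Minimal sketch (not from the report): isolates the masked tl.load.
import torch
import triton
import triton.language as tl


@triton.jit
def masked_load_debug(in_ptr, xnumel, XBLOCK: tl.constexpr):
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    # No `other` argument: the masked-out lanes (idx 5..7) are where the two
    # backends diverge.
    tmp0 = tl.load(in_ptr + xindex, mask=xmask)
    tl.device_print("tmp0", tmp0)
    # Passing other=0 would make the masked-out lanes well defined on any
    # backend; Inductor's generated kernels omit it and rely on zero-fill.
    # tmp0 = tl.load(in_ptr + xindex, mask=xmask, other=0)


if __name__ == "__main__":
    dev = "xpu"  # switch to "cuda" to compare against CUDA Triton
    x = torch.zeros(5, dtype=torch.int64, device=dev)
    masked_load_debug[(1,)](x, x.numel(), XBLOCK=8)
    torch.xpu.synchronize()  # torch.cuda.synchronize() on CUDA

On CUDA Triton this is expected to print zeros for idx 5..7, while the affected Intel GPU build printed the undefined values shown above.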

We did some investigation you may refer to. The CUDA behavior is implemented here: https://github.com/triton-lang/triton/blob/d7c8b3d7890125f5fc1b9f046e3189baa2665be4/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp#L237C1-L241C8

Versions

Latest Triton llvm-target branch.

etaf commented 4 months ago

@riverliuintel @vlad-penkin This issue seriously blocks XPU PyTorch upstream work; please prioritize.

etaf commented 4 months ago

Hi @riverliuintel @vlad-penkin, is anyone looking at this issue?

alexbaden commented 4 months ago

I am not able to reproduce this using the LTS Agama driver (803), commit https://github.com/intel/intel-xpu-backend-for-triton/commit/21bd53648dfb88d593f53aac085787951082b302, and a Data Center GPU Max 1100.

$ python test/inductor/test_torchinductor.py -k test_nll_loss_backward                   
inline_call []
stats [('calls_captured', 4), ('unique_graphs', 2)]
inductor [('fxgraph_cache_miss', 1), ('fxgraph_cache_hit', 1)]
aot_autograd [('total', 2), ('ok', 2)]
.inline_call []
stats [('calls_captured', 8), ('unique_graphs', 4)]
aot_autograd [('total', 4), ('ok', 4)]
inductor [('intermediate_hooks', 4), ('fxgraph_cache_miss', 2), ('fxgraph_cache_hit', 2)]
.
----------------------------------------------------------------------
Ran 2 tests in 133.891s

OK

Is there a different way to invoke the test and see the expected failure?

etaf commented 4 months ago

Hi @alexbaden, I've sent you an email with steps to reproduce on our private machine with the LTS Agama driver (803). Please check the mail.