intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

[PyTorch Upstream] HuggingFace two model E2E training accuracy regression. #1255

Open etaf opened 1 month ago

etaf commented 1 month ago

After upgrading the Triton commit pin from b8c64f64c18d8cac598b3adb355c21e7439c21de (the commit currently used by stock PyTorch) to 514e4cdf004278c82216364d1f8534b940cd4238 (the 2.4 release candidate), we found accuracy failures in two models with the same PyTorch version (everything is identical except Triton):

xpu,DebertaForMaskedLM,1,fail_accuracy,804,1,3,3,0,0,0
xpu,DebertaForQuestionAnswering,1,fail_accuracy,811,1,3,3,0,0,0
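
When comparing runs across Triton pins, it can help to pull the failing models out of result rows like the two above mechanically. A minimal sketch, assuming (from these two rows only) that the second comma-separated field is the model name and the fourth is the accuracy status; `failing_models` is a hypothetical helper, not part of the test script.

```python
import csv
import io


def failing_models(csv_text: str) -> list[tuple[str, str]]:
    """Return (model, status) for each row whose status field starts with "fail".

    Assumption: field 1 is the model name and field 3 is the accuracy status,
    as in the two result rows quoted in this issue. Header, blank, or short
    lines are skipped.
    """
    failures = []
    for row in csv.reader(io.StringIO(csv_text)):
        if len(row) < 4:
            continue  # blank or truncated line
        model, status = row[1], row[3]
        if status.startswith("fail"):
            failures.append((model, status))
    return failures


results = """\
xpu,DebertaForMaskedLM,1,fail_accuracy,804,1,3,3,0,0,0
xpu,DebertaForQuestionAnswering,1,fail_accuracy,811,1,3,3,0,0,0
"""
print(failing_models(results))
# [('DebertaForMaskedLM', 'fail_accuracy'), ('DebertaForQuestionAnswering', 'fail_accuracy')]
```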

I've also tried the earliest 2.4 release candidate commit, dd072259e1df5c0cd42034d90ec5f464f22e5ae3, and it passed; I hope this narrows the scope of your regression search.

To reproduce:

Use the stock PyTorch main branch plus this PR: https://github.com/pytorch/pytorch/pull/126516
Test script: https://github.com/intel/torch-xpu-ops/blob/main/.github/scripts/inductor_xpu_test.sh
Driver: 803.29
Command:

bash -x inductor_xpu_test.sh huggingface float32 training accuracy xpu 0 static 1 0 DebertaForMaskedLM

alexbaden commented 1 month ago

I was able to reproduce the failure with dd07225. However, b8c64f4 passed.

Bisecting produced the following result:

There are only 'skip'ped commits left to test.
The first bad commit could be any of:
059f57c4ecd193dd438d784fb28b074b68317c6c
1af93bd370222e79bebe9f970a392673275e9af6
36a692825f34a0c0702e3e6bc87c4bb2a01bc22a
64b4ed717721833569014abf41095adad939fcf4
455e68f8dbd734833621902ca3475aa9fb5a8501
We cannot bisect more!

(Note that, because of the merge-commit structure, we cannot build the OpenAI commits in isolation.)

I am investigating the above set of commits now to see if I can narrow down the issue.
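
The "only 'skip'ped commits left" result is what `git bisect` reports when every remaining candidate was marked as untestable. With `git bisect run`, a test script signals this by exiting with status 125 (0 means good, other codes from 1 to 127 mean bad). A minimal sketch of such a script, assuming a hypothetical build command and test-script location; `bisect_check.py`, `BUILD_CMD`, and `TEST_CMD` are illustrative names, not the setup used above.

```python
"""Hypothetical `git bisect run` helper for the Triton accuracy regression.

git bisect run interprets the exit status: 0 = good, 125 = skip (commit
cannot be tested), and other codes from 1 to 127 = bad.
"""
import subprocess
import sys

GOOD, BAD, SKIP = 0, 1, 125

# Hypothetical commands: adapt the build step and the path to
# inductor_xpu_test.sh to your own checkout and environment.
BUILD_CMD = ["pip", "install", "-e", "python", "--no-build-isolation"]
TEST_CMD = [
    "bash", "inductor_xpu_test.sh", "huggingface", "float32", "training",
    "accuracy", "xpu", "0", "static", "1", "0", "DebertaForMaskedLM",
]


def exit_code(build_ok: bool, test_passed: bool) -> int:
    """Map one bisect step's outcome to the exit status git expects."""
    if not build_ok:
        return SKIP  # unbuildable commits are skipped, not marked bad
    return GOOD if test_passed else BAD


def main() -> int:
    if subprocess.run(BUILD_CMD).returncode != 0:
        return exit_code(build_ok=False, test_passed=False)
    result = subprocess.run(TEST_CMD, capture_output=True, text=True)
    # Assumption: the accuracy status string appears in the script output.
    output = result.stdout + result.stderr
    passed = result.returncode == 0 and "fail_accuracy" not in output
    return exit_code(build_ok=True, test_passed=passed)


if __name__ == "__main__" and sys.argv[1:] == ["--run"]:
    # Only build and test when invoked explicitly with --run.
    sys.exit(main())
```

Usage would be along the lines of `git bisect start <bad-commit> <good-commit>` followed by `git bisect run python bisect_check.py --run`. When the first bad commit sits among commits that only build as part of a merge, they all come back as 125, and git can only list them as candidates, as in the output above.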

alexbaden commented 1 month ago

This is the bad commit: https://github.com/intel/intel-xpu-backend-for-triton/commit/059f57c4ecd193dd438d784fb28b074b68317c6c#diff-00e5b15068049beb083ce4359833adf047e65d36691c575685a5995277ff26c6R839

etaf commented 1 month ago

Hi @alexbaden, I double-checked, and commit dd07225 does fail. Sorry for misleading your debugging.

vlad-penkin commented 1 month ago

The accuracy issue is reproducible with this Triton kernel (generated by TorchInductor for the DebertaForQuestionAnswering backward pass):

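import triton
import triton.language as tl
# Inductor runtime helpers; this import path matches the PyTorch 2.4 timeframe
# and differs in older releases.
from torch._inductor.runtime import triton_helpers


@triton.jit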
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr3, out_ptr4, xnumel, rnumel):
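    # Persistent reduction: each program handles one row x0 (XBLOCK = 1) and all
    # 768 columns at once, padded to RBLOCK = 1024 and masked by rmask.
    xnumel = 512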
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask & xmask, other=0.0)
    tmp3 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask & xmask, other=0.0)
    tmp7 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp37 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask & xmask, other=0.0).to(tl.int1)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 * tmp3
    tmp5 = -tmp4
    tmp8 = tmp6 / tmp7
    tmp9 = tmp8 / tmp7
    tmp10 = tmp5 * tmp9
    tmp11 = tl.broadcast_to(tmp10, [RBLOCK])
    tmp13 = tl.where(rmask & xmask, tmp11, 0)
    tmp14 = triton_helpers.promote_to_tensor(tl.sum(tmp13, 0))
    tmp15 = tmp4 / tmp7
    tmp16 = -tmp15
    tmp17 = tl.broadcast_to(tmp16, [RBLOCK])
    tmp19 = tl.where(rmask & xmask, tmp17, 0)
    tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
    tmp21 = 2.0
    tmp22 = tmp7 * tmp21
    tmp23 = tmp14 / tmp22
    tmp24 = 0.0013020833333333333  # equals 1/768, i.e. 1/rnumel
    tmp25 = tmp23 * tmp24
    tmp26 = tmp6 * tmp21
    tmp27 = tmp25 * tmp26
    tmp28 = -tmp27
    tmp29 = tl.broadcast_to(tmp28, [RBLOCK])
    tmp31 = tl.where(rmask & xmask, tmp29, 0)
    tmp32 = triton_helpers.promote_to_tensor(tl.sum(tmp31, 0))
    tmp33 = tmp15 + tmp27
    tmp34 = tmp20 + tmp32
    tmp35 = tmp34 * tmp24
    tmp36 = tmp33 + tmp35
    tmp38 = 0.0
    # Zero the elements flagged by in_ptr5, then scale by 1.111... (= 1/0.9).
    tmp39 = tl.where(tmp37, tmp38, tmp36)
    tmp40 = 1.1111111111111112
    tmp41 = tmp39 * tmp40
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp36, rmask & xmask)
    tl.store(out_ptr4 + (r1 + (768*x0)), tmp41, rmask & xmask)
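
To judge the XPU output, it helps to have an independent reference for the same arithmetic. Below is a minimal float64 sketch in plain Python that mirrors the kernel's tmp chain for a single row x0; `reference_row` and `max_abs_diff` are hypothetical helpers, not part of Inductor. One sanity check falls out of the arithmetic: tmp20 and tmp32 are the negated row sums of tmp15 and tmp27, and tmp24 equals 1/768, so each row of out_ptr3 should sum to approximately zero regardless of the inputs.

```python
INV_RNUMEL = 0.0013020833333333333  # tmp24 in the kernel, equal to 1/768
SCALE = 1.1111111111111112          # tmp40 in the kernel, equal to 1/0.9


def reference_row(in0, in1, in2, in3, in4_x, in5):
    """Float64 reference for one row x0 of the triton_ kernel.

    in0, in1, in3, in5: the 768 elements at offset 768*x0 of in_ptr0/1/3/5
    in2: the 768-element in_ptr2 vector (shared by every row)
    in4_x: the scalar in_ptr4[x0]
    Returns (out3_row, out4_row), matching out_ptr3 and out_ptr4.
    """
    tmp7 = in4_x
    tmp4 = [(a + b) * w for a, b, w in zip(in0, in1, in2)]
    # tmp14 = sum over the row of (-tmp4) * ((tmp6 / tmp7) / tmp7)
    tmp14 = sum(-t4 * ((c / tmp7) / tmp7) for t4, c in zip(tmp4, in3))
    tmp15 = [t4 / tmp7 for t4 in tmp4]
    tmp20 = sum(-t for t in tmp15)
    tmp25 = (tmp14 / (tmp7 * 2.0)) * INV_RNUMEL
    tmp27 = [tmp25 * (c * 2.0) for c in in3]
    tmp32 = sum(-t for t in tmp27)
    tmp35 = (tmp20 + tmp32) * INV_RNUMEL
    # tmp36 = (tmp15 + tmp27) + tmp35
    out3 = [t15 + t27 + tmp35 for t15, t27 in zip(tmp15, tmp27)]
    # tmp39/tmp41: zero where the in_ptr5 mask is set, otherwise scale.
    out4 = [0.0 if m else v * SCALE for m, v in zip(in5, out3)]
    return out3, out4


def max_abs_diff(xs, ys):
    """Largest elementwise |x - y|, for comparing device output to the reference."""
    return max(abs(x - y) for x, y in zip(xs, ys))
```

With PyTorch tensors, a row could be checked with something like `reference_row(t0[x].tolist(), t1[x].tolist(), t2.tolist(), t3[x].tolist(), t4[x].item(), t5[x].tolist())`, where `t0` to `t5` are hypothetical names for the kernel's inputs viewed as (512, 768) (with `t2` and `t4` 1-D), followed by `max_abs_diff` against the same row of the XPU output. The float64 reference sums in a different order than the GPU's tree reduction, so expect small float32-level differences rather than exact equality.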