etaf opened 1 month ago
I was able to reproduce the failure with dd07225. However, b8c64f4 passed.
Bisecting produced the following result:
```
There are only 'skip'ped commits left to test.
The first bad commit could be any of:
059f57c4ecd193dd438d784fb28b074b68317c6c
1af93bd370222e79bebe9f970a392673275e9af6
36a692825f34a0c0702e3e6bc87c4bb2a01bc22a
64b4ed717721833569014abf41095adad939fcf4
455e68f8dbd734833621902ca3475aa9fb5a8501
We cannot bisect more!
```
(Note that due to the merge commit structure we cannot build openai commits in isolation.)
I am investigating the above set of commits now to see if I can narrow down the issue.
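For reference, the skip-heavy ending above can be reproduced in miniature on a throwaway repository (a hypothetical toy repo with placeholder commits; the real run was against the Triton tree, where the skipped commits would not build):

```shell
# Miniature reproduction of the "only skipped commits left" bisect ending
# (hypothetical throwaway repo; commit messages are placeholders).
repo=$(mktemp -d)
cd "$repo"
git init -q
git config user.email bisect@example.com
git config user.name bisect
for i in 1 2 3 4; do
    echo "$i" > file.txt
    git add file.txt
    git commit -qm "commit $i"
done
git bisect start HEAD HEAD~3   # bad = commit 4, good = commit 1
git bisect skip                # candidate under test does not build: skip it
git bisect skip || true        # last candidate skipped too; git now prints:
                               # "There are only 'skip'ped commits left to test.
                               #  The first bad commit could be any of: ...
                               #  We cannot bisect more!"
```

Once every remaining candidate is skipped, git can only report the ambiguous set, which is exactly the output shown above.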
Hi @alexbaden, I double-checked and commit dd07225 does fail; sorry for misleading your debugging.
The accuracy issue is reproducible with this Triton kernel (generated by TorchInductor for the DebertaForQuestionAnswering backward pass):
```python
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr3, out_ptr4, xnumel, rnumel):
    xnumel = 512
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (768*x0)), rmask & xmask, other=0.0)
    tmp3 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_ptr3 + (r1 + (768*x0)), rmask & xmask, other=0.0)
    tmp7 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp37 = tl.load(in_ptr5 + (r1 + (768*x0)), rmask & xmask, other=0.0).to(tl.int1)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 * tmp3
    tmp5 = -tmp4
    tmp8 = tmp6 / tmp7
    tmp9 = tmp8 / tmp7
    tmp10 = tmp5 * tmp9
    tmp11 = tl.broadcast_to(tmp10, [RBLOCK])
    tmp13 = tl.where(rmask & xmask, tmp11, 0)
    tmp14 = triton_helpers.promote_to_tensor(tl.sum(tmp13, 0))
    tmp15 = tmp4 / tmp7
    tmp16 = -tmp15
    tmp17 = tl.broadcast_to(tmp16, [RBLOCK])
    tmp19 = tl.where(rmask & xmask, tmp17, 0)
    tmp20 = triton_helpers.promote_to_tensor(tl.sum(tmp19, 0))
    tmp21 = 2.0
    tmp22 = tmp7 * tmp21
    tmp23 = tmp14 / tmp22
    tmp24 = 0.0013020833333333333
    tmp25 = tmp23 * tmp24
    tmp26 = tmp6 * tmp21
    tmp27 = tmp25 * tmp26
    tmp28 = -tmp27
    tmp29 = tl.broadcast_to(tmp28, [RBLOCK])
    tmp31 = tl.where(rmask & xmask, tmp29, 0)
    tmp32 = triton_helpers.promote_to_tensor(tl.sum(tmp31, 0))
    tmp33 = tmp15 + tmp27
    tmp34 = tmp20 + tmp32
    tmp35 = tmp34 * tmp24
    tmp36 = tmp33 + tmp35
    tmp38 = 0.0
    tmp39 = tl.where(tmp37, tmp38, tmp36)
    tmp40 = 1.1111111111111112
    tmp41 = tmp39 * tmp40
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp36, rmask & xmask)
    tl.store(out_ptr4 + (r1 + (768*x0)), tmp41, rmask & xmask)
```
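Two literals in the kernel are worth decoding when checking its output by hand: 0.0013020833333333333 looks like 1/768 (the inverse of rnumel, i.e. the hidden size), and 1.1111111111111112 looks like the dropout rescale 1/(1 - p) with p = 0.1. This is my reading of the generated code, not something stated in the issue; the arithmetic itself is easy to verify:

```python
# Check that the kernel's float literals match 1/768 and 1/(1 - 0.1).
# hidden = rnumel in the kernel; p_drop = 0.1 is an assumption based on
# the usual DeBERTa dropout probability, not taken from the issue.
hidden = 768
p_drop = 0.1

assert 1 / hidden == 0.0013020833333333333
assert 1 / (1 - p_drop) == 1.1111111111111112
print("kernel literals match 1/768 and 1/(1 - 0.1)")
```

If that reading is right, the final `tl.where(tmp37, 0.0, tmp36) * 1.1111...` is the dropout-mask backward step, which is a useful place to start when comparing outputs against an eager-mode reference.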
After upgrading the Triton commit pin from b8c64f64c18d8cac598b3adb355c21e7439c21de (currently used by stock PyTorch) to 514e4cdf004278c82216364d1f8534b940cd4238 (the 2.4 release candidate), we see the two model accuracy errors on the same PyTorch version (everything is identical except Triton). I have also tried the earliest 2.4 release-candidate commit, dd072259e1df5c0cd42034d90ec5f464f22e5ae3, and it passed; I hope this narrows the scope of your regression search.
To reproduce:
- Use the stock PyTorch main branch plus this PR: https://github.com/pytorch/pytorch/pull/126516
- Test script: https://github.com/intel/torch-xpu-ops/blob/main/.github/scripts/inductor_xpu_test.sh
- Driver: 803.29
- Command: