It does look like int8 again:
tmp3 = tl_any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1)
and then casting to int1?
The previous fix promoted to int32 since it's free and more correct anyway, but maybe there are gaps in that fix...
Yeah, the upstream comment mentions they cast to tl.int8 because tl.reduce doesn't support tl.int1, and then cast back to tl.int1 afterwards: https://github.com/pytorch/pytorch/commit/94d306fd4594a642ac70f3f5d2a9eb4910587774#diff-ea3834aeda2e22f49dd5ba0becd456b449b10d17dfe64c34e89072701c82873dR1285
Ah ok, that makes sense.
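For anyone following along, here is a minimal, self-contained sketch of that int1 -> int8 -> int1 round trip. The kernel name, shapes, and launch parameters are made up for illustration; this is not the inductor-generated code, and it assumes a Triton 2.x-style API where bool tensors are passed as *i1 pointers (matching the reproducer signature):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def any_kernel(in_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Hypothetical single-block "aten.any"-style reduction over a bool tensor.
    offs = tl.arange(0, BLOCK)
    mask = offs < n_elements
    # Out-of-bounds lanes contribute False (0) to the reduction.
    x = tl.load(in_ptr + offs, mask=mask, other=0).to(tl.int1)
    # tl.reduce-based reductions don't accept tl.int1, so widen to int8 ...
    acc = tl.max(x.to(tl.int8), axis=0)
    # ... then narrow back to int1, mirroring the generated tl_any line above.
    tl.store(out_ptr, acc.to(tl.int1))


inp = torch.rand(100, device="cuda") > 0.5
out = torch.zeros(1, device="cuda", dtype=torch.bool)
any_kernel[(1,)](inp, out, inp.numel(), BLOCK=128)
assert out.item() == inp.any().item()
```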
@jayfurmanek @zhanglx13 @alefimov-amd cc: @jithunnair-amd @dllehr-amd
This error is observed with top-of-tree (TOT) triton-mlir and breaks some HuggingFace models as well as some PyTorch unit tests. It is likely to block us from moving the PyTorch commit forward.
Experiments: I have found this seems to occur specifically with num_warps=2. If we change the reproducer's triton.compile line to num_warps=4, the error is not observed:
test = triton.compile(triton_fn, signature="*i1,*fp32,*i1,i32,i32", constants={"XBLOCK": 1, "RBLOCK": 128}, num_warps=4)
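For contrast, a sketch of the failing call: per the observation above it is identical to the working line except for num_warps (triton_fn here is the reproducer kernel, not shown in this comment):

```python
# Same reproducer kernel and signature as the working line above;
# only num_warps differs. num_warps=2 reproduces the failure on TOT triton-mlir.
test = triton.compile(
    triton_fn,
    signature="*i1,*fp32,*i1,i32,i32",
    constants={"XBLOCK": 1, "RBLOCK": 128},
    num_warps=2,
)
```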
Full parameters (failing):
Reproducer:
Note: @jayfurmanek mentioned that this commit previously fixed a similar error: https://github.com/ROCmSoftwarePlatform/triton/pull/347/commits/a41f13adcd4ae623d69f77290008fa1d7b10b490
This PyTorch PR may also be relevant: https://github.com/pytorch/pytorch/commit/94d306fd4594a642ac70f3f5d2a9eb4910587774. It goes over some of the accumulator logic for aten.any, but it is fairly old at this point.