jgong5 opened 1 year ago
Repro with triton commit: af76c989eb4799b015f8b288ccd8421558772e56
cc @jansel
I don't think it's technically a bug. The auto-tuner will simply run all the configs one by one and see which one is fastest. In this case, you have:

```
tmp7 = tl.load(in_out_ptr0 + (r2 + (384*x3)), xmask & rmask, eviction_policy='evict_last')
...
tl.store(in_out_ptr0 + (r2 + (384*x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp19, xmask & rmask)
```

so when the kernel runs twice, the first call modifies in_out_ptr0, which leads to the wrong value being loaded in the second call.
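To make the failure mode concrete, here is a toy pure-Python model of the problem (no Triton involved; `inplace_kernel` and `naive_benchmark` are illustrative stand-ins, not real APIs). The autotuner times each config by re-running the kernel, but an in-place kernel mutates its input, so every call after the first sees already-modified data:

```python
def inplace_kernel(buf):
    # Stand-in for a Triton kernel that both reads and writes in_out_ptr0.
    for i in range(len(buf)):
        buf[i] = buf[i] * 2 + 1

def naive_benchmark(kernel, buf, reps=2):
    # The autotuner re-runs the kernel to time it; no copy is made.
    for _ in range(reps):
        kernel(buf)

data = [1.0, 2.0, 3.0]
expected = [x * 2 + 1 for x in data]   # result of a single run: [3.0, 5.0, 7.0]
naive_benchmark(inplace_kernel, data)
print(data == expected)  # False: the second run clobbered the first result
```

The second run turns `[3.0, 5.0, 7.0]` into `[7.0, 11.0, 15.0]`, which is exactly the "wrong value loaded in the second call" described above.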
The same issue arose when doing atomic_add in a matmul with SPLIT_K>1, and we fixed it by adding a hook in the auto-tuner that resets the desired arguments to zeros between runs:
https://github.com/openai/triton-mlir/blob/main/python/triton/runtime/autotuner.py#L25-L33
I imagine in your case we'd want the hook to initialize data from a copy.
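A copy-based hook along those lines could look like the following sketch. This is not the actual Triton autotuner API; `make_copy_hook` and `benchmark` are hypothetical names, and plain Python lists stand in for device buffers. The idea is simply to snapshot the mutated arguments once and restore them before each timed run:

```python
import copy

def make_copy_hook(args, reset_indices):
    # Snapshot the arguments the kernel mutates, before any run.
    snapshots = {i: copy.deepcopy(args[i]) for i in reset_indices}
    def pre_hook():
        for i, snap in snapshots.items():
            args[i][:] = snap          # restore in place before each run
    return pre_hook

def benchmark(kernel, args, reset_indices, reps=3):
    pre_hook = make_copy_hook(args, reset_indices)
    for _ in range(reps):
        pre_hook()                     # kernel always sees the original data
        kernel(*args)

def inplace_kernel(buf):
    # Toy in-place kernel: reads and overwrites its input.
    for i in range(len(buf)):
        buf[i] = buf[i] * 2 + 1

data = [1.0, 2.0, 3.0]
benchmark(inplace_kernel, [data], reset_indices=[0])
print(data)  # [3.0, 5.0, 7.0] -- same as a single run, despite 3 timed runs
```

The zero-reset hook used for the SPLIT_K atomic_add case is the same mechanism, just with `pre_hook` writing zeros instead of restoring a snapshot.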
Re-tagging this as an enhancement as we probably want to revisit the autotuner API and document this behavior properly. I don't think it'll be possible to have a general analysis pass that determines which arguments should be copied upfront -- this would work in simple cases but likely would fail when aliasing behavior is dynamic.
Ah I think you are correct. The autotuner runs the kernel multiple times, which is not safe to do for in-place kernels because the first run clobbers the input data.
Note TorchInductor has a subclass of the autotuner, so for this case we would need to modify https://github.com/pytorch/torchdynamo/blob/6380959be21851bfda99424392cc08fda29d073d/torchinductor/triton_ops/autotune.py#L145 to create copies of all the input tensors using: https://github.com/pytorch/torchdynamo/blob/6380959be21851bfda99424392cc08fda29d073d/torchinductor/compile_fx.py#L145
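An alternative shape for the TorchInductor-side fix is to benchmark against fresh clones of the inputs rather than restoring them, so the caller's buffers are never touched. The sketch below is hypothetical: TorchInductor would clone real torch tensors (preserving strides, per the linked helper), while `copy.deepcopy` on lists stands in for that here:

```python
import copy

def benchmark_on_clones(kernel, args, reps=3):
    # Each timed run gets its own pristine copy of every input,
    # so in-place writes never leak between runs or back to the caller.
    for _ in range(reps):
        cloned = [copy.deepcopy(a) for a in args]
        kernel(*cloned)
    return cloned  # last run's buffers, in case the caller wants the output

def inplace_kernel(buf):
    # Toy in-place kernel: reads and overwrites its input.
    for i in range(len(buf)):
        buf[i] = buf[i] * 2 + 1

data = [1.0, 2.0, 3.0]
out = benchmark_on_clones(inplace_kernel, [data])
print(data)    # [1.0, 2.0, 3.0] -- original left untouched
print(out[0])  # [3.0, 5.0, 7.0] -- result of exactly one kernel run
```

Cloning costs extra memory and copy time per run, whereas the restore-hook approach copies once and rewrites in place; which trade-off is better depends on buffer sizes and how many configs are tuned.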
@ptillet Thanks for the insights. I'd appreciate it if you could update the docs to describe this limitation and the recommended mitigation before a complete solution lands. Memory-related issues are hard to troubleshoot; if Triton could report a warning via some debugging flag, that would make users' lives easier.
Repro:
Expected result: no assertion failure, but the second assertion fails. If only one `triton.Config` is left (i.e. only `triton.Config({'XBLOCK': 1, 'RBLOCK': 64})`), the test passes. Check https://github.com/pytorch/torchdynamo/issues/1670 for more background info.