Currently, Triton forces the use of I64 indexing in order to avoid overflow. However, I believe it should be possible to infer from the grid size whether I32 indexing would overflow, and to require recompilation or alternate dispatch when it would. This may help in cases of severe register pressure.
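To make the alternate-dispatch idea concrete, here is a minimal sketch (names like `pick_index_dtype` are assumptions for illustration, not a real Triton API) of choosing an I32 or I64 kernel variant from an upper bound on the largest linear index a launch can produce:

```python
INT32_MAX = 2**31 - 1

def pick_index_dtype(grid_size: int, block_size: int, element_stride: int = 1) -> str:
    """Hypothetical host-side dispatch helper: return the index dtype to
    compile the kernel with, based on the worst-case linear index."""
    max_index = (grid_size * block_size - 1) * element_stride
    return "i32" if max_index <= INT32_MAX else "i64"

# Small launch: the largest index fits comfortably in i32.
print(pick_index_dtype(grid_size=1024, block_size=256))    # i32

# Huge launch: the largest index exceeds 2**31 - 1, so fall back to i64.
print(pick_index_dtype(grid_size=2**20, block_size=4096))  # i64
```

A real implementation would key the kernel cache on the chosen dtype, so crossing the threshold triggers recompilation rather than silent wraparound.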
The `ptrAdd(ptr, offset)` operation will ultimately happen in i64/u64. When combined with https://github.com/openai/triton/pull/2314, this will reduce the need for i64 registers by performing the offset arithmetic in I32.
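The risk of doing the offset arithmetic in I32 is silent wraparound before the widening to i64 in `ptrAdd`. A small pure-Python illustration (simulating two's-complement i32, not Triton code) of how a row-major offset can wrap:

```python
def wrap_i32(x: int) -> int:
    """Truncate an integer to signed 32-bit two's-complement,
    mimicking what i32 hardware arithmetic would produce."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x

# Hypothetical offset for a row-major buffer: offset = row * stride + col.
row, stride, col = 70_000, 40_000, 123

offset_i64 = row * stride + col                       # exact: 2_800_000_123
offset_i32 = wrap_i32(wrap_i32(row * stride) + col)   # wraps negative in i32

print(offset_i64)  # 2800000123 -- exceeds 2**31 - 1
print(offset_i32)  # negative after wraparound
```

If the wrapped i32 value is then sign-extended to i64 inside `ptrAdd`, the resulting pointer is wrong, which is exactly what the overflow guard below is meant to rule out.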
Handling I32 Overflows
One can create a compilation guard on the kernel grid size as a form of delayed interval analysis to ensure that we will never overflow.
The methodology would be to use `torch.SymInt`-style tracing to trace all of the arithmetic expressions involved in offset computations, and guard on whether a given grid size will cause any of those expressions to overflow.
We can consider I32 indexing the happy path (we assume that buffers are < 4GB in most scenarios) and throw an error if overflow is detected. We can also add a `use_i32_indexing=True` config param; the user can be advised to set it to `False` if they encounter the error.
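The guard described above can be sketched as a tiny interval analysis: each symbolic value carries the maximum its expression can reach given bounds on grid size and shapes, and the guard checks that maximum against `2**31 - 1`. All names here (`Sym`, `guard_i32`) are illustrative assumptions, not existing Triton or PyTorch APIs:

```python
INT32_MAX = 2**31 - 1

class Sym:
    """A symbolic value tracking only the maximum it can take
    (a one-sided interval analysis; maxima are assumed non-negative)."""
    def __init__(self, max_val: int):
        self.max_val = max_val

    def __add__(self, other):
        return Sym(self.max_val + _max(other))
    __radd__ = __add__

    def __mul__(self, other):
        return Sym(self.max_val * _max(other))
    __rmul__ = __mul__

def _max(x) -> int:
    return x.max_val if isinstance(x, Sym) else x

def guard_i32(offset: Sym, use_i32_indexing: bool = True) -> bool:
    """Return True if the traced offset expression provably fits in i32;
    raise if the user asked for i32 indexing but it cannot be guaranteed."""
    fits = offset.max_val <= INT32_MAX
    if use_i32_indexing and not fits:
        raise OverflowError(
            "offset may overflow I32; set use_i32_indexing=False")
    return fits

# Example: offset = pid * BLOCK * stride + local_idx, with a guarded
# upper bound on the grid size (pid < 4096).
pid = Sym(max_val=4096 - 1)
BLOCK, stride = 128, 4096
offset = pid * BLOCK * stride + Sym(max_val=BLOCK - 1)
print(guard_i32(offset))  # True: worst case is 2_146_959_487 <= INT32_MAX
```

A guard like this evaluates once per (grid-size bucket, shape) combination at launch time, so the happy path pays no per-element cost.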
Note: this tracing/guard does not need to live in the backend. However, regardless of whether the trace/guard/dispatch happens in the frontend or the backend, it is difficult to integrate with upstream frameworks such as TorchInductor and JAX Pallas.
Motivation
Measurements in https://github.com/openai/triton/issues/2301 suggest that under severe register pressure, using I32 instead of I64 helps performance. More investigation is required on this front.
Originally posted by @jon-chuang in https://github.com/openai/triton/issues/2301#issuecomment-1722488961