triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Perf (`priority: low`): Explore utilizing I32 indexing for `make_block_ptr` with grid size compilation guards #2324

Open jon-chuang opened 1 year ago

jon-chuang commented 1 year ago

Currently, `make_block_ptr` forces the use of I64 indexing to avoid overflow. However, I believe it should be possible to infer from the grid size whether I32 indexing would overflow, and to require recompilation or an alternate dispatch when it would. This may help in cases with severe register pressure.

Originally posted by @jon-chuang in https://github.com/openai/triton/issues/2301#issuecomment-1722488961

The final ptrAdd(ptr, offset) will still happen in i64/u64. Combined with https://github.com/openai/triton/pull/2314, performing the offset arithmetic in I32 will reduce the need for i64 registers.
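A minimal pure-Python sketch of the split described above, modelling only the arithmetic (not Triton's actual lowering): the offset is computed with every intermediate wrapped to signed 32-bit, and only the final pointer add is done in 64-bit. The helper names `wrap_i32`, `offset_i32` and `ptr_add` are hypothetical. The second call shows why a guard is needed: once an intermediate exceeds `INT32_MAX`, the i32 offset silently wraps.

```python
# Hypothetical model of the arithmetic only; not Triton's actual lowering.
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def wrap_i32(x: int) -> int:
    """Reduce x to signed 32-bit two's complement, as an i32 register would."""
    return (x + 2**31) % 2**32 - 2**31


def offset_i32(row: int, col: int, stride_row: int, stride_col: int) -> int:
    """row * stride_row + col * stride_col with every intermediate kept in i32."""
    a = wrap_i32(row * stride_row)
    b = wrap_i32(col * stride_col)
    return wrap_i32(a + b)


def ptr_add(base_addr: int, offset: int, elem_size: int) -> int:
    """Final 64-bit pointer add: the signed element offset is scaled to bytes."""
    return base_addr + offset * elem_size


# Small tile: the i32 offset equals the exact offset.
print(offset_i32(100, 5, 1024, 1))      # 102405
# Large tensor: 70000 * 65536 exceeds INT32_MAX, so the i32 offset wraps.
print(offset_i32(70000, 0, 65536, 1))   # 292552704, not 4587520000
```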

Handling I32 Overflows

One can create a compilation guard on the kernel grid size, as a form of delayed interval analysis, to ensure that the offset arithmetic never overflows.
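To make "delayed interval analysis" concrete, here is a minimal sketch assuming a 2-D block pointer whose element offset is `sum_d (pid_d * block_d + i_d) * stride_d`, with the block shape and strides known at launch. The analysis is "delayed" because the ranges of the program ids only become known once the grid is chosen. `Interval`, `block_offset_range` and `i32_indexing_is_safe` are hypothetical names; the check tracks every intermediate, not just the final offset, because an intermediate overflow is enough to corrupt the result.

```python
# Hypothetical guard; names and the 2-D block-pointer model are illustrative.
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


class Interval:
    """Closed integer range [lo, hi] that remembers whether any step left i32."""

    def __init__(self, lo: int, hi: int, all_fit: bool = True):
        self.lo, self.hi = lo, hi
        # True only if this range and every range used to produce it fit in i32.
        self.all_fit = all_fit and INT32_MIN <= lo and hi <= INT32_MAX

    @staticmethod
    def of(x):
        return x if isinstance(x, Interval) else Interval(x, x)

    def __add__(self, other):
        o = Interval.of(other)
        return Interval(self.lo + o.lo, self.hi + o.hi, self.all_fit and o.all_fit)

    def __mul__(self, other):
        o = Interval.of(other)
        ps = (self.lo * o.lo, self.lo * o.hi, self.hi * o.lo, self.hi * o.hi)
        return Interval(min(ps), max(ps), self.all_fit and o.all_fit)


def block_offset_range(grid, block, strides):
    """Range of offset = sum_d (pid_d * block_d + i_d) * stride_d over the grid."""
    total = Interval(0, 0)
    for g, b, s in zip(grid, block, strides):
        pid = Interval(0, g - 1)   # program id along this dim
        idx = Interval(0, b - 1)   # index inside the block
        total = total + (pid * b + idx) * s
    return total


def i32_indexing_is_safe(grid, block, strides) -> bool:
    return block_offset_range(grid, block, strides).all_fit


# 4096 x 4096 row-major matrix, 128 x 64 blocks -> grid (32, 64): fits in i32.
print(i32_indexing_is_safe((32, 64), (128, 64), (4096, 1)))       # True
# 65536 x 65536 matrix -> grid (512, 1024): max offset is 2**32 - 1.
print(i32_indexing_is_safe((512, 1024), (128, 64), (65536, 1)))   # False
```

Because the result depends only on the grid, block shape and strides, an I32-compiled variant could in principle be reused for any launch whose guard still holds.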

The methodology would be to use torch.SymInt-style tracing to record all of the arithmetic expressions involved in offset computations, and to guard on whether a given grid size would cause any of those expressions to overflow.
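The tracing idea can be sketched with a toy symbolic integer in the spirit of `torch.SymInt`. This does not use PyTorch's actual API; `Sym`, `var`, `bounds`, `tile_offset` and `guard` are hypothetical. The offset computation is written once as ordinary Python, run once on symbolic inputs to capture an expression tree (compile time), and the resulting guard is then evaluated cheaply for each concrete grid and stride combination (launch time).

```python
from dataclasses import dataclass
from typing import Dict, Tuple

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


@dataclass(frozen=True)
class Sym:
    """Toy symbolic integer: operator overloading records an expression tree."""
    op: str                # "var", "const", "add" or "mul"
    args: tuple = ()
    name: str = ""
    value: int = 0

    def __add__(self, other): return Sym("add", (self, _lift(other)))
    def __radd__(self, other): return Sym("add", (_lift(other), self))
    def __mul__(self, other): return Sym("mul", (self, _lift(other)))
    def __rmul__(self, other): return Sym("mul", (_lift(other), self))


def _lift(x):
    return x if isinstance(x, Sym) else Sym("const", value=x)


def var(name: str) -> Sym:
    return Sym("var", name=name)


def bounds(e: Sym, ranges: Dict[str, Tuple[int, int]]) -> Tuple[int, int, bool]:
    """(lo, hi, all_fit): interval of e, and whether every subexpression fits i32."""
    if e.op == "const":
        lo = hi = e.value
        ok = True
    elif e.op == "var":
        lo, hi = ranges[e.name]
        ok = True
    else:
        (alo, ahi, aok), (blo, bhi, bok) = (bounds(a, ranges) for a in e.args)
        if e.op == "add":
            lo, hi = alo + blo, ahi + bhi
        else:
            ps = (alo * blo, alo * bhi, ahi * blo, ahi * bhi)
            lo, hi = min(ps), max(ps)
        ok = aok and bok
    return lo, hi, ok and INT32_MIN <= lo and hi <= INT32_MAX


def tile_offset(pid_m, pid_n, i, j, BLOCK_M, BLOCK_N, stride_m, stride_n):
    # The same offset arithmetic a kernel would perform, written once.
    return (pid_m * BLOCK_M + i) * stride_m + (pid_n * BLOCK_N + j) * stride_n


# Trace once: block sizes are compile-time constants, everything else symbolic.
BLOCK_M, BLOCK_N = 128, 64
traced = tile_offset(var("pid_m"), var("pid_n"), var("i"), var("j"),
                     BLOCK_M, BLOCK_N, var("stride_m"), var("stride_n"))


def guard(grid, strides) -> bool:
    """Per-launch check: does every traced expression fit in i32 for this grid?"""
    ranges = {
        "pid_m": (0, grid[0] - 1), "pid_n": (0, grid[1] - 1),
        "i": (0, BLOCK_M - 1), "j": (0, BLOCK_N - 1),
        "stride_m": (strides[0], strides[0]), "stride_n": (strides[1], strides[1]),
    }
    return bounds(traced, ranges)[2]


print(guard((32, 64), (4096, 1)))      # True
print(guard((512, 1024), (65536, 1)))  # False
```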

We can treat I32 indexing as the happy path (assuming that buffers are < 4GB in most scenarios) and throw an error if overflow is detected. We can also add a use_i32_indexing=True config parameter; users who encounter the error can be advised to set it to False.
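A sketch of the happy-path policy just described, assuming non-negative strides and the same 2-D block-pointer offset model as elsewhere in this thread. `use_i32_indexing` follows the proposed config parameter; `select_index_dtype` and `I32IndexingOverflowError` are hypothetical, and returning a dtype string stands in for choosing which compiled variant to launch. With non-negative strides, every partial product and partial sum is bounded by the final maximum offset, so checking that single value is enough.

```python
INT32_MAX = 2**31 - 1


class I32IndexingOverflowError(RuntimeError):
    """Hypothetical error raised when a launch could overflow i32 offsets."""


def max_block_offset(grid, block, strides):
    # Largest element offset any program can touch, assuming non-negative strides:
    # (pid_max * block + idx_max) * stride = (g * b - 1) * s per dimension.
    return sum((g * b - 1) * s for g, b, s in zip(grid, block, strides))


def select_index_dtype(grid, block, strides, use_i32_indexing=True):
    if not use_i32_indexing:
        return "int64"  # user opted out: always take the safe path
    worst = max_block_offset(grid, block, strides)
    if worst > INT32_MAX:
        raise I32IndexingOverflowError(
            f"max offset {worst} exceeds INT32_MAX for grid={grid}; "
            "rerun with use_i32_indexing=False")
    return "int32"


print(select_index_dtype((32, 64), (128, 64), (4096, 1)))  # int32
```

An alternative to raising would be to fall back silently to an I64-compiled variant (the "alternate dispatch" option), trading a possible recompilation for not surfacing the flag to users at all.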

Note: this tracing/guarding does not need to live in the backend. However, whether the trace/guard/dispatch happens in the frontend or the backend, it is difficult to integrate with upstream frameworks such as TorchInductor and JAX Pallas.

Motivation

Measurements in https://github.com/openai/triton/issues/2301 suggest that, under severe register pressure, using I32 instead of I64 helps. More investigation is required on this front.

jon-chuang commented 1 year ago

Actually, it is easier to simply expose use_int32_arith as a parameter of make_block_ptr. But again, it's a corner case, and it is hard to tell users to use it.