Closed groenenboomj closed 5 months ago
The PR needs some major reversions in order to make it the unified one b/w AOTriton:

- Use ALL UPPER CASE for all `tl.constexpr` input arguments; `extra_tokens_n` is not `tl.constexpr`. `tl.constexpr` means an extra template argument and exponential growth of AOTriton kernels. `bias_type` should not be a string. I don't know how to pass a string as `tl.constexpr` right now; I can find out, but let's not create extra trouble by comparing strings on the GPU.
- Do not check input values to enable/disable functionality. Notably, do not use `if bias_ptr is not None:`; use `if ENABLE_BIAS` instead and add `ENABLE_BIAS` to the argument list.
- Do not add an underscore before the kernel entrance, e.g. `_attn_fwd` -> `attn_fwd`. For AOTriton these functions are public.
- Please move anything that needs `torch` out of this file. It's creating circular dependencies again.
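A minimal sketch of a kernel signature that follows these conventions (illustrative only, not the actual flash-attention kernel from this PR; the pointer names, block sizes, and body are assumptions): `tl.constexpr` arguments in ALL UPPER CASE, a compile-time `ENABLE_BIAS` flag instead of checking `bias_ptr` or comparing a `bias_type` string, and a public entry point without a leading underscore.

```python
import triton
import triton.language as tl


@triton.jit
def attn_fwd(                         # public entry point: no leading underscore
    q_ptr, k_ptr, out_ptr, bias_ptr,  # illustrative pointers, not the real argument list
    extra_tokens_n,                   # runtime value, deliberately NOT tl.constexpr
    BLOCK_M: tl.constexpr,            # compile-time constants in ALL UPPER CASE
    BLOCK_N: tl.constexpr,
    ENABLE_BIAS: tl.constexpr,        # boolean flag instead of a bias_type string
):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    q = tl.load(q_ptr + offs_m)
    k = tl.load(k_ptr + offs_n)
    acc = q[:, None] * k[None, :]
    if ENABLE_BIAS:
        # Resolved at compile time, so the disabled path costs nothing and
        # no string comparison happens on the GPU.
        bias = tl.load(bias_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :])
        acc += bias
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)
```

Each distinct combination of the `tl.constexpr` flags becomes a separately compiled AOTriton kernel, which is why only true compile-time switches should be constexpr and runtime values like `extra_tokens_n` should stay ordinary arguments.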
Most of these are reasonable. Let's discuss at the sync and I can get most of this in with the potential split.
The potential split:

- `flash_fwd.py`: for the forward kernel
- `flash_bwd.py`: for the backward kernel
- `test_flash.py`: for testing
- `perf_flash.py`: for autotune
- `flash_common.py`: constants and common functions (e.g. `dropout_mask`)
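As a rough illustration of the split and of keeping `torch` out of the kernel files, `flash_common.py` could hold Triton-only helpers; the `dropout_mask` signature and body below are assumptions for the sketch, not the code from this PR.

```python
# flash_common.py -- sketch of a torch-free shared helper (hypothetical signature).
import triton
import triton.language as tl


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, stride,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Per-element random draws from the Philox generator; keep an element when
    # its draw exceeds dropout_p. No torch import here, so the kernel files
    # that import this module cannot pull torch back in and recreate the cycle.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    rng = tl.rand(philox_seed, philox_offset + offs_m[:, None] * stride + offs_n[None, :])
    return rng > dropout_p
```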