ROCm / triton

Development repository for the Triton language and compiler
MIT License

Merge AOT features #460

groenenboomj closed this pull request 5 months ago

xinyazhang commented 6 months ago

The PR needs some major revisions to become the unified version shared with AOTriton:

  1. Use ALL UPPER CASE for all `tl.constexpr` input arguments.
  2. `extra_tokens_n` should not be a `tl.constexpr`. Each `tl.constexpr` argument acts as an extra template argument, which causes exponential growth in the number of AOTriton kernels.
  3. `bias_type` should not be a string. I don't know how to pass a string as a `tl.constexpr` at the moment. I could find out, but let's not create extra trouble by comparing strings on the GPU.
  4. Do not check input values to enable or disable functionality. Notably, do not use `if bias_ptr is not None:`; use `if ENABLE_BIAS:` instead and add `ENABLE_BIAS` to the argument list.
  5. Do not add an underscore before the kernel entry point, e.g. rename `_attn_fwd` to `attn_fwd`. For AOTriton these functions are public.
  6. Please move anything that needs `torch` out of this file; it is creating circular dependencies again. Suggested split:
    • `flash_fwd.py`: forward kernel
    • `flash_bwd.py`: backward kernel
    • `test_flash.py`: testing
    • `perf_flash.py`: autotuning
    • `flash_common.py`: constants and common functions (e.g. `dropout_mask`)
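Items 1 and 5 are naming conventions that a small check could enforce. The sketch below is a hypothetical lint helper (`lint_kernel_names` is not part of Triton or AOTriton) that parses kernel source with Python's `ast` module, so Triton never has to be imported (`ast.unparse` needs Python 3.9+). It assumes kernels are decorated with `@triton.jit` and constexpr arguments are annotated as `tl.constexpr`; it also treats every jitted function as a public entry point, which is stricter than item 5.

```python
import ast


def _source_text(node):
    """Return the source text of an AST node, or "" for a missing node."""
    return ast.unparse(node) if node is not None else ""


def lint_kernel_names(source):
    """Hypothetical lint: flag @triton.jit functions that break items 1 and 5."""
    problems = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.FunctionDef):
            continue
        # Accept both the @triton.jit and @triton.jit(...) decorator forms.
        decorators = [
            _source_text(d.func if isinstance(d, ast.Call) else d)
            for d in node.decorator_list
        ]
        if "triton.jit" not in decorators:
            continue
        # Item 5: entry points are public, so no leading underscore.
        if node.name.startswith("_"):
            problems.append(f"{node.name}: entry point should not start with '_'")
        # Item 1: every tl.constexpr argument should be ALL UPPER CASE.
        for arg in node.args.args + node.args.kwonlyargs:
            if _source_text(arg.annotation) == "tl.constexpr" and arg.arg != arg.arg.upper():
                problems.append(f"{node.name}: constexpr {arg.arg} should be upper case")
    return problems


SAMPLE = '''
import triton
import triton.language as tl

@triton.jit
def _attn_fwd(Q, K, V, sm_scale, extra_tokens_n,
              block_m: tl.constexpr, ENABLE_BIAS: tl.constexpr):
    pass
'''

for problem in lint_kernel_names(SAMPLE):
    print(problem)
# _attn_fwd: entry point should not start with '_'
# _attn_fwd: constexpr block_m should be upper case
```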
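To make the growth in item 2 concrete: an AOT build compiles one kernel per combination of `tl.constexpr` values, so the number of variants is the product of the number of choices for each argument, and k arguments with m choices each give m^k kernels. The option table below is made up for illustration; only the arithmetic is the point.

```python
from itertools import product
from math import prod

# Made-up constexpr options; names follow item 1's ALL UPPER CASE rule.
CONSTEXPR_CHOICES = {
    "BLOCK_M": [64, 128],
    "BLOCK_N": [32, 64],
    "ENABLE_BIAS": [False, True],
    "ENABLE_DROPOUT": [False, True],
}


def count_variants(choices):
    """Kernels an AOT build must compile: product of per-argument option counts."""
    return prod(len(values) for values in choices.values())


def enumerate_variants(choices):
    """Yield one {name: value} dict per compiled kernel variant."""
    names = list(choices)
    for values in product(*(choices[name] for name in names)):
        yield dict(zip(names, values))


base = count_variants(CONSTEXPR_CHOICES)
# If extra_tokens_n were a tl.constexpr taking, say, 16 distinct values,
# every existing variant would be compiled once more per value.
with_extra = count_variants({**CONSTEXPR_CHOICES, "EXTRA_TOKENS_N": list(range(16))})
print(base, with_extra)  # 16 256
```

Keeping `extra_tokens_n` as an ordinary runtime argument leaves the count at 16 in this example: the kernel reads the value at launch instead of having it baked into each binary.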
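Items 3 and 4 could be handled together on the host side: encode the bias type as an integer instead of a string, and derive an explicit `ENABLE_BIAS` flag before launch so the kernel branches on `if ENABLE_BIAS:` rather than on `bias_ptr is not None`. `BiasType` and `bias_constexpr_args` are hypothetical names for this sketch, not an existing AOTriton API, and `object()` stands in for a bias tensor.

```python
from enum import IntEnum


class BiasType(IntEnum):
    """Hypothetical integer encoding for bias_type (item 3).

    Plain integers are easy to pass to a kernel and cheap to compare,
    unlike strings.
    """
    NONE = 0
    MATRIX = 1
    VECTOR = 2


def bias_constexpr_args(bias, bias_type=BiasType.NONE):
    """Decide the bias-related kernel flags on the host, before launch (item 4).

    The kernel would then test the flag, e.g.
        if ENABLE_BIAS:
            ...load and add the bias tile...
    instead of testing `bias_ptr is not None`, so the branch is resolved
    at compile time for each variant rather than by inspecting inputs.
    """
    enable = bias is not None
    if enable and bias_type == BiasType.NONE:
        raise ValueError("bias given but bias_type is NONE")
    if not enable and bias_type != BiasType.NONE:
        raise ValueError("bias_type set but no bias given")
    # Passing BIAS_TYPE as a tl.constexpr would also multiply variants (item 2);
    # whether it should be one is left open in this sketch.
    return {"ENABLE_BIAS": enable, "BIAS_TYPE": int(bias_type)}


print(bias_constexpr_args(None))
# {'ENABLE_BIAS': False, 'BIAS_TYPE': 0}
print(bias_constexpr_args(object(), BiasType.MATRIX))  # object() stands in for a tensor
# {'ENABLE_BIAS': True, 'BIAS_TYPE': 1}
```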
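For item 6, one way to keep the split honest is a check that the kernel-only files never import `torch`. This assumes `flash_fwd.py`, `flash_bwd.py`, and `flash_common.py` are meant to be torch-free while `test_flash.py` and `perf_flash.py` may use it. `torch_violations` is a hypothetical helper built on Python's `ast` module; it also catches imports nested inside functions, and the file contents below are stand-ins.

```python
import ast

# Files from item 6 assumed to be torch-free (kernels and shared helpers).
TORCH_FREE = {"flash_fwd.py", "flash_bwd.py", "flash_common.py"}


def imported_top_level_modules(source):
    """Collect top-level module names imported anywhere in `source`."""
    names = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            # Relative imports (level > 0) stay inside the package; skip them.
            names.add(node.module.split(".")[0])
    return names


def torch_violations(sources):
    """Return the torch-free files (name -> source text) that import torch."""
    return sorted(
        name for name, text in sources.items()
        if name in TORCH_FREE and "torch" in imported_top_level_modules(text)
    )


sources = {
    "flash_fwd.py": "import triton\nimport triton.language as tl\nfrom flash_common import dropout_mask\n",
    "flash_common.py": "import torch\nimport triton\n",
    "test_flash.py": "import torch\nfrom flash_fwd import attn_fwd\n",
}
print(torch_violations(sources))  # ['flash_common.py']
```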

groenenboomj commented 6 months ago

Most of these are reasonable. Let's discuss at the sync; I can get most of this in along with the potential split.