ROCm / aotriton

Ahead of Time (AOT) Triton Math Library
MIT License

Switch to upstream Triton compiler, and related changes #36

Closed xinyazhang closed 4 months ago

xinyazhang commented 4 months ago
  1. Switch to the performance kernel for the forward pass. The old Triton kernel does not work with the new compiler.
  2. Support AOT based autotune, which includes
    1. Add argument aotriton::v2::flash::ExtraArguments to all aotriton::v2::flash APIs
    2. Add build option AOTRITON_BUILD_FOR_TUNING to build all possible GPU kernels. The configurations are supplied by KernelDescription.gen_autotune_configs, which is compatible with triton.Config.
    3. AOTRITON_BUILD_FOR_TUNING also enables force_kernel_index and other fields in aotriton::v2::flash::ExtraArguments. Users can manually select a kernel and bypass the autotune mechanism.
    4. Add test/tune_flash.py and cpp_autotune.py, and change test/attn_torch_function.py to support AOT autotune (aka cpp autotune)
      • test/tune_flash.py runs the UT before measuring a triton.Config's performance, to avoid including faulty kernels.
  3. Add Navi31/32 compiler options (but not added to the default config due to compiler problems)
  4. Add --use_multigpu to test/tune_flash.py. The script now supports tuning GPU kernels on all GPUs simultaneously, with the following extra features:
    • It also moves the UT into a separate process (referred to as the minesweeper process here), so that a faulty kernel that triggers a segfault does not crash the worker process.
      • Thus tune_flash.py needs 1(main)+n(worker)+n(minesweeper)+1(db access)+1*(table_tool.py) processes
      • For better performance, the minesweeper process is reused and is only recreated if the previous one hits a segfault (or another failure).
    • --json_file is also added, since the new architecture has a unified database access process that accepts outputs from all worker processes, and this new process can write to a separate JSON file. This is the currently recommended way to store the results of the tuning script. Users are expected to run v2python.table_tool later to update the tuning database.
    • --continue_from_json_file is introduced. Meanwhile, the result and _debug_task_id fields are also attached to the output JSON object, so that a tuning process can be resumed according to the _debug_task_id and its tuning status.
    • v2python.table_tool is improved to support the new version of the JSON file
  5. Tuning results of the forward kernel are updated for MI200/MI300X with the new compiler. Most UTs pass (see comments for known failures on MI300X)
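
For readers unfamiliar with the minesweeper idea in item 4, here is a minimal sketch of the pattern, not the code in test/tune_flash.py. It assumes a helper process that receives one task per line over a pipe; the Minesweeper class, the CHILD_SOURCE program and the "crash" task name are all hypothetical and only illustrate "reuse the helper, recreate it only after it dies".

```python
import subprocess
import sys

# Hypothetical helper program: reads one task id per line and replies "ok <id>".
# The task id "crash" simulates a faulty kernel taking the whole process down.
CHILD_SOURCE = r"""
import os, sys
for line in sys.stdin:
    task = line.strip()
    if task == "crash":
        os._exit(1)  # stand-in for a segfault inside a faulty GPU kernel
    print("ok " + task, flush=True)
"""


class Minesweeper:
    """Long-lived helper process that is only recreated after it dies."""

    def __init__(self):
        self.proc = None
        self.spawn_count = 0

    def _ensure_alive(self):
        # Spawn a helper only if there is none or the previous one has exited.
        if self.proc is None or self.proc.poll() is not None:
            self.proc = subprocess.Popen(
                [sys.executable, "-c", CHILD_SOURCE],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                text=True,
            )
            self.spawn_count += 1

    def _reap(self):
        # Close both pipes, wait for the helper to exit, and forget it.
        for stream in (self.proc.stdin, self.proc.stdout):
            try:
                stream.close()
            except OSError:
                pass
        self.proc.wait()
        self.proc = None

    def check(self, task):
        """Return True if the task passed, False if the helper crashed."""
        self._ensure_alive()
        try:
            self.proc.stdin.write(task + "\n")
            self.proc.stdin.flush()
            reply = self.proc.stdout.readline()
        except OSError:
            reply = ""
        if reply.strip() == "ok " + task:
            return True
        # An empty reply means EOF: the helper died while running this task.
        self._reap()
        return False

    def close(self):
        if self.proc is not None:
            self._reap()


sweeper = Minesweeper()
results = [sweeper.check(t) for t in ["cfg0", "cfg1", "crash", "cfg2"]]
sweeper.close()
print(results, sweeper.spawn_count)
```

In this sketch the crash costs one helper, while the caller (the worker in the PR's terminology) keeps running and simply marks that configuration as faulty.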

CAVEAT: The new AOT based autotune script test/tune_flash.py isn't capable of handling the backward pass yet.
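
The resume behaviour of --continue_from_json_file can be pictured roughly as follows. This is only a sketch under assumptions: it supposes the JSON file holds one object per line and that a result value of "tuned" marks a finished task. Only the field names result and _debug_task_id come from the description above; the file layout, the status string and both function names are hypothetical.

```python
import io
import json


def finished_task_ids(json_lines, done_value="tuned"):
    """Collect _debug_task_id values whose result equals done_value.

    Assumes one JSON object per line; done_value is a hypothetical status.
    """
    done = set()
    for line in json_lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if record.get("result") == done_value:
            done.add(record["_debug_task_id"])
    return done


def remaining_tasks(all_task_ids, json_lines):
    """Return the task ids that still need tuning, in their original order."""
    done = finished_task_ids(json_lines)
    return [task_id for task_id in all_task_ids if task_id not in done]


# Made-up output of an earlier, interrupted run (illustrative data only).
previous_run = io.StringIO(
    '{"_debug_task_id": 0, "result": "tuned"}\n'
    '{"_debug_task_id": 1, "result": "failed"}\n'
    '{"_debug_task_id": 2, "result": "tuned"}\n'
)
todo = remaining_tasks(range(5), previous_run)
print(todo)
```

Whether a failed task should be retried or skipped on resume is a policy choice; the sketch retries it.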

xinyazhang commented 4 months ago

Known failures:

test_op_bwd_with_matrix_bias[False-1.2-dtype2-0.0-2048-143-256-4-4]
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True

Produced by pytest test/test_backward.py -v -k 1.2 on MI300X
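
To see at a glance which gradient is failing in a log like the one above, the FAILED lines can be tallied with a small script. This is an illustrative sketch, not part of the repository; the summarize function and both regular expressions are assumptions based only on the line format shown in this comment.

```python
import re
from collections import Counter

# Matches: FAILED <file>::<test>[<params>] - AssertionError: <flags>
FAILED_RE = re.compile(
    r"^FAILED\s+(\S+)::(\w+)\[([^\]]*)\]\s+-\s+AssertionError:\s+(.*)$"
)
# Matches flags such as dv_allclose=False
FLAG_RE = re.compile(r"(\w+)_allclose=(True|False)")


def summarize(lines):
    """Count, per gradient name, how many FAILED lines report a mismatch."""
    mismatches = Counter()
    for line in lines:
        match = FAILED_RE.match(line.strip())
        if match is None:
            continue  # not a FAILED line in the expected format
        for name, ok in FLAG_RE.findall(match.group(4)):
            if ok == "False":
                mismatches[name] += 1
    return mismatches


# Two lines copied from the log above.
log = [
    "FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True",
    "FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True",
]
summary = summarize(log)
print(summary)
```

Every test_op_bwd failure listed here reports dv_allclose=False with the other three flags True, so the tally contains only dv.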

xinyazhang commented 4 months ago

The UT on MI200 has better results:

FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True

Tested with pytest test/test_backward.py -v -k 1.2

xinyazhang commented 4 months ago

> Looks mostly good. What is the new library size?

I don't have the size for the all-architecture build without zstd. The MI300X-only build with zstd is 321M.