nod-ai / SHARK-Turbine

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

[attention] Extend attention to fuse transpose #669

Closed: antiagainst closed this 1 month ago

antiagainst commented 3 months ago

Update 5/22: patch https://github.com/iree-org/iree/pull/17408 is out; it needs review.

Update 5/23: working on decomposition and tiling. A patch should be out today or so.

Groverkss commented 3 months ago

Plan to finish it this week (before Jun 7):

- 4 Jun: Land online attention (https://github.com/iree-org/iree/pull/17536); see the sketch after this list for the online-softmax idea
- 5 Jun: Create transform script using online_attention for MFMA
- 6 Jun: Add indexing_maps to attention op
- 7 Jun: Fusions
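
For context, "online attention" here refers to computing softmax attention one K/V block at a time with running statistics, so the full score matrix is never materialized. A minimal PyTorch sketch of that idea (shapes, block size, and function name are illustrative, not taken from the IREE implementation):

```python
import torch

def online_attention(q, k, v, block=32):
    """Single-head attention computed one K/V block at a time with a running
    (online) softmax, so the full S x S score matrix is never materialized."""
    scale = q.shape[-1] ** -0.5
    S = k.shape[0]
    m = torch.full((q.shape[0],), float("-inf"))   # running row max
    l = torch.zeros(q.shape[0])                    # running softmax denominator
    acc = torch.zeros(q.shape[0], v.shape[-1])     # running weighted sum of V
    for start in range(0, S, block):
        k_blk = k[start:start + block]
        v_blk = v[start:start + block]
        s = (q @ k_blk.T) * scale
        m_new = torch.maximum(m, s.max(dim=-1).values)
        p = torch.exp(s - m_new[:, None])
        corr = torch.exp(m - m_new)                # rescale earlier partial results
        l = l * corr + p.sum(dim=-1)
        acc = acc * corr[:, None] + p @ v_blk
        m = m_new
    return acc / l[:, None]

# Sanity check against the naive full-softmax formulation.
q, k, v = (torch.randn(128, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * q.shape[-1] ** -0.5, dim=-1) @ v
print(torch.allclose(online_attention(q, k, v), ref, atol=1e-4))
```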

raikonenfnu commented 1 month ago

Hey guys, quick update

  1. The attention indexing_maps change (https://github.com/iree-org/iree/pull/17864) has landed
  2. CastTypeToFitMMA support for the TD pipeline (https://github.com/iree-org/iree/pull/17884) is up
  3. transfer_write distribution for non-contiguous indexing maps (https://github.com/iree-org/iree/pull/17895) is up

Once 2, 3, and https://github.com/iree-org/iree/commit/d2ca77402becf4c6476893845ba96116b61df9c1 have landed on main, we should be able to compile the fused attention-transpose.
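
For reference, a minimal PyTorch sketch of the pattern in question: attention whose output is immediately transposed, which the work above aims to compile as a single fused dispatch. Shapes are illustrative only, not taken from the test IR:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes; the real test IR uses its own sizes.
B, H, S, D = 2, 8, 128, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Attention followed by a transpose of the result; the fusion aims to fold the
# transpose into the attention dispatch instead of emitting a separate copy.
attn = F.scaled_dot_product_attention(q, k, v)   # (B, H, S, D)
out = attn.transpose(1, 2).contiguous()          # (B, S, H, D)
```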

antiagainst commented 1 month ago

Awesome. All 3 pull requests are in. Can you send out the last piece?

raikonenfnu commented 1 month ago

> Awesome. All 3 pull requests are in. Can you send out the last piece?

Hey Lei, I think @MaheshRavishankar is en route to pushing that one in! :)

MaheshRavishankar commented 1 month ago

I can send it in early next week.

raikonenfnu commented 1 month ago

I also pushed an update to the spec MLIR so it finds k2 correctly (link). I tested compiling the fusion-preprocessing test MLIR (here) and was able to get a vmfb out.

The gist above is slightly different from the test in that we make the scale a constant here. It fails on vector distribution if the scale is not constant.

compile command:

~/nod/iree-build-notrace/tools/iree-compile constant_transpose_fusion.mlir \
  --iree-hal-target-backends=rocm \
  --iree-rocm-target-chip=gfx942 \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-outer-dim-concat=true \
  --iree-opt-const-eval=false \
  --iree-opt-data-tiling=false \
  --iree-rocm-waves-per-eu=2 \
  --iree-vm-target-truncate-unsupported-floats \
  --iree-codegen-llvmgpu-use-vector-distribution \
  --iree-codegen-gpu-native-math-precision=true \
  --iree-flow-enable-aggressive-fusion \
  -o attention.vmfb \
  --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir

raikonenfnu commented 1 month ago

FYI, I also tested the attention-transpose-fusion vmfb numerics against torch on standard-normal random inputs (mean 0.0, std 1.0); the numerics look good :)

[numerics_test screenshot]

Starting IR, compile command, data generator can all be found in https://gist.github.com/raikonenfnu/973b4d91e4378702ce4b4496d732cb57

I needed to tweak the shapes from the original fusion-preprocessing test slightly, since the fastest dim of Q, K, and V needs to be the same to run on PyTorch.
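
A rough sketch of the kind of numerics check described here, assuming the compiled module's output has been dumped to a .npy file (file name, shapes, seed, and tolerances are all illustrative; the actual IR, compile command, and data generator are in the gist above):

```python
import numpy as np
import torch
import torch.nn.functional as F

# Illustrative shapes; the fastest (last) dim of Q, K, V must match so the
# same tensors can feed both PyTorch and the compiled module.
B, H, S, D = 1, 8, 256, 64

# Standard-normal inputs (mean 0.0, std 1.0), as in the test above.
torch.manual_seed(0)
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# PyTorch reference: attention followed by the transpose being fused.
ref = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).contiguous()

# Hypothetical file: output of the compiled attention.vmfb, saved as .npy
# (e.g. from an iree-run-module invocation plus a small conversion step).
iree_out = np.load("iree_attention_output.npy")

print(np.allclose(ref.numpy(), iree_out, rtol=1e-3, atol=1e-3))
```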