ROCm / composable_kernel

Composable Kernel: Performance Portable Programming Model for Machine Learning Tensor Operators
https://rocm.docs.amd.com/projects/composable_kernel/en/latest/

WMMA / RDNA3+ kernels for backwards fused attention? #1434

Open · Googulator opened this issue 2 months ago

Googulator commented 2 months ago

Problem Description

Composable Kernel currently only supports fused attention (FA2) on RDNA3(+) architectures in the forward direction. This greatly increases the VRAM requirements for training LoRAs on LLMs with HuggingFace's Transformers and PEFT libraries: training jobs that succeed on an NVIDIA GeForce RTX 4080 with just 16 GB of VRAM fail on a Radeon RX 7900 XT with 20 GB.
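For reference, a minimal sketch of the kind of LoRA fine-tuning setup that runs into this (the model name, LoRA rank, and other hyper-parameters are purely illustrative, not part of this report): requesting `attn_implementation="flash_attention_2"` helps the forward pass, but the backward pass is where the missing RDNA3 kernel shows up.

```python
# Illustrative only: a tiny LoRA fine-tuning step through Transformers + PEFT.
# Model name and hyper-parameters are placeholders, not a reproduction recipe.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    # Requests the FA2 kernels; on RDNA3 only the forward direction is
    # currently covered by CK, so training memory use balloons.
    attn_implementation="flash_attention_2",
    device_map="auto",
)

lora_cfg = LoraConfig(
    r=16, lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg)

tok = AutoTokenizer.from_pretrained(model_name)
batch = tok("Example training text", return_tensors="pt").to(model.device)

# The backward call is where a fused-attention backward kernel would matter:
# without one, attention activations have to be kept or recomputed unfused.
out = model(**batch, labels=batch["input_ids"])
out.loss.backward()
```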

Based on https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal and https://github.com/Repeerc/sd-webui-flash-attention2-rdna3-rocm, it seems possible to implement a usable WMMA-based backward fused attention kernel. Unfortunately I can't use those repositories directly, as both are tailored to image generation (Stable Diffusion), whereas I am interested in FA2 support for LLM training.

Are there any plans for adding fused attention backward pass support for RDNA3+ GPUs to CK in the foreseeable future? This seems especially pressing now that the W7900 Dual Slot, an RDNA3 GPU, is being recommended for AI workstation use, where making effective use of its 48 GB of VRAM during training feels much more like a core use case.

Operating System

Ubuntu 22.04 LTS

CPU

AMD Ryzen 9 7950X (non-3D)

GPU

AMD Radeon RX 7900 XTX, AMD Radeon Pro W7900, AMD Radeon Pro W7800, AMD Radeon RX 7900 XT

Other

No response

ROCm Version

ROCm 6.0.0

ROCm Component

Composable Kernel

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

demonsan commented 13 hours ago

I have been trying to understand ck_tile in preparation for writing FA kernels for the 7900 series, but I am confused by the tile window part. In the old CK we could use threadgroup and thread slice transfer, but now we have to use tile_window. The parameters of tile_window are hard to understand, and there are few comments :(