ROCm / composable_kernel

Composable Kernel: Performance Portable Programming Model for Machine Learning Tensor Operators
https://rocm.docs.amd.com/projects/composable_kernel/en/latest/

[CK Tile] Generic attention masking support for FMHA fwd and bwd #1340

Open cameronshinn opened 3 months ago

cameronshinn commented 3 months ago

The goal of these changes is to support generic attention masks for the FMHA operator in CK Tile. The motivation is to enable a variety of masking strategies from existing research, beyond just the existing causal masking. The two I aim to add support for are:

Big Bird

Longformer

With the existing masking from SimplifiedGenericAttentionMask, masks could only be expressed as a diagonal window, which limits us to windowed attention or causal attention. Additionally, the FMHA fwd/bwd pipelines assume that the non-zero (unmasked) tiles in a tile row (column for backward) are contiguous. Neither of these can support the Big Bird and Longformer masks.
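To make that limitation concrete, here is a rough sketch of the diagonal-window style of predicate (names are illustrative, not the SimplifiedGenericAttentionMask interface): within each row the unmasked columns form a single contiguous interval around the diagonal, which is what the pipelines' contiguity assumption relies on.

```cpp
// Illustrative diagonal-window predicate (illustrative names, not the actual
// SimplifiedGenericAttentionMask interface). Causal masking is roughly
// left = sequence length, right = 0.
struct DiagonalWindowSketch
{
    int left;  // columns visible to the left of the diagonal
    int right; // columns visible to the right of the diagonal

    bool IsUnmasked(int row, int col) const
    {
        // Per row, the kept columns are the single interval [row - left, row + right].
        return col >= row - left && col <= row + right;
    }
};
```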

These changes instead let the main mask interface, GenericAttentionMask, accept a mask definition object that holds the mask-specific details. A different mask definition can be passed in for each kind of mask; for example, DiagonalMask mimics the previous windowed masking. The required signature of a mask definition is laid out in the MaskDefABC struct.
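As a rough illustration of why a pluggable definition helps, here is a sketch of a Longformer-style definition (the member name IsUnmasked and everything else here are assumptions for illustration, not the actual MaskDefABC signature): a local window plus a few global tokens, which a single diagonal window cannot express.

```cpp
// Hypothetical mask definition in the spirit of MaskDefABC (illustrative
// only; the real required signature is laid out in MaskDefABC itself).
// Longformer-style pattern: a sliding local window plus "global" tokens that
// attend to, and are attended by, everything.
struct LongformerLikeMaskDefSketch
{
    int window;     // half-width of the local attention window
    int num_global; // the first num_global tokens are global

    bool IsUnmasked(int row, int col) const
    {
        const bool local  = col >= row - window && col <= row + window;
        const bool global = row < num_global || col < num_global;
        return local || global; // kept columns in a row are no longer one interval
    }
};
```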

Masks can also be defined at tile granularity instead of per-element granularity, signaled by IsTileMask. This is helpful since Big Bird uses block sparsity. The tile sizes need to be members of the struct somehow, and I found it easiest to make them template parameters (x_tile, y_tile).
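A hedged sketch of what a tile-granularity definition could look like under that scheme (everything here is illustrative, not the actual CK Tile code): the predicate is evaluated per tile rather than per element, in the spirit of Big Bird's block sparsity.

```cpp
// Hypothetical tile-granularity mask definition; the tile sizes are template
// parameters as described above. The pattern (a band of tile-diagonals plus
// a "global" first tile row/column) is only a stand-in for a Big Bird-style
// block-sparse layout, not the actual implementation.
template <int x_tile, int y_tile>
struct BlockSparseMaskDefSketch
{
    static constexpr bool IsTileMask = true;
    static constexpr int kXTile      = x_tile;
    static constexpr int kYTile      = y_tile;

    int band; // number of tile-diagonals kept on each side of the main diagonal

    bool IsUnmaskedTile(int tile_row, int tile_col) const
    {
        const bool local  = tile_col >= tile_row - band && tile_col <= tile_row + band;
        const bool global = tile_row == 0 || tile_col == 0;
        return local || global;
    }
};
```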

The pipelines have been modified to use an IndexIterator to skip to the next non-zero tile, since the non-zero tiles can now be non-contiguous. The iterator walks the tile mask indices, incrementing until a non-zero tile is found.
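Roughly, the iteration looks like the following sketch (illustrative names, not the actual IndexIterator API in the pipelines):

```cpp
// Sketch of the skip-ahead idea: advance the tile-column index within a tile
// row until the mask definition reports a non-zero tile or the row ends.
template <typename MaskDef>
int NextNonZeroTile(const MaskDef& mask, int tile_row, int tile_col, int num_tile_cols)
{
    // Tradeoff 5 in the list below: every index in between has to be checked.
    while (tile_col < num_tile_cols && !mask.IsUnmaskedTile(tile_row, tile_col))
    {
        ++tile_col;
    }
    return tile_col; // equals num_tile_cols when no further non-zero tile exists
}
```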

From what I can tell, the tradeoffs are as follows:

  1. 😊 Construct a variety of mask types easily
  2. 😊 Scalable to arbitrary sequence lengths
  3. 😊 Mask is defined in instructions rather than sparse data structure arrays
  4. ☹️ Mask size is unknown without evaluating the predicate across the entire attention matrix index space
  5. ☹️ The next non-zero tile in a row can't be determined without checking every index in between

I am opening this as a draft PR to initiate discussion. I am still working on adding mask definitions for Big Bird and Longformer, as well as gathering performance results to show (verifying that there isn't any perf regression for the existing masking).