pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License
5.66k stars 514 forks source link

Question about large sequence length attention kernels #140

Open loubbrad opened 7 months ago

loubbrad commented 7 months ago

I really love this project and the accompanying blogpost, so thanks! I've reimplemented some of the inference techniques to speed up an implementation of Whisper that I am using. I had a few questions the attention kernels, as they have been giving me some some performance related issues.

By adding print statements, I can see that during the attention calculation (not including prefilling) the shapes are essentially:

k - (bs, n_heads, max_seq_len, d_head)
v - (bs, n_heads, max_seq_len, d_head)

I understand that max_seq_len is there because of the static KV cache implementation. My understanding is that due to the attention mask, the F.scaled_dot_product_attention combined with torch.compile should be able to tell that it doesn't need to calculate the attention over the entire max_seq_len. In my case however, I've found that the max_seq_len value has a big effect on the inference speed, which suggests to me that the full attention (over the entire max_seq_len context) is being performed on every iteration. This is vastly reduced when using the following context manager, as is done in generate.py:

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True)

If I exclude this, I am seeing a 3x reduction in tok/s. Does this make sense? Or it is a sign I've implemented something wrong? Even with this context manager, I see a significant (50%+) increase in tok/s when I reduce the context length from 4096 to 2048 or 1024.

Thanks in advance. If it helps, here is my cuda graph friendly Whisper implementation using a static KV cache:

class ModelConfig:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
    n_vocab: Optional[int] = None

    def set_vocab_size(self, vocab_size: int):
        self.n_vocab = vocab_size

class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val, v_val: [B, H, L, D]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

def sinusoids(
    length: int, channels: int, max_timescale: float = 10000
) -> torch.Tensor:
    """Returns sinusoids for positional embedding"""
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(
        -log_timescale_increment * torch.arange(channels // 2)
    )
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)

class EncoderAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly devide n_state"

        self.n_head = n_head
        self.d_head = n_state // n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(
        self,
        xa: Tensor,
    ):
        q = self.query(xa)
        k = self.key(xa)
        v = self.value(xa)

        # Reshape for correct format
        batch_size, source_seq_len, _ = k.shape
        batch_size, target_seq_len, _ = q.shape
        q = q.view(
            batch_size, target_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        k = k.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        v = v.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            is_causal=False,
        )
        wv = wv.transpose(1, 2).reshape(
            batch_size,
            target_seq_len,
            self.n_head * self.d_head,
        )

        return self.out(wv)

class CrossAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly devide n_state"

        self.n_head = n_head
        self.d_head = n_state // n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.kv_cache: KVCache | None = None

    def get_kv(self, xa: torch.Tensor, xa_input_pos: Tensor):
        assert self.kv_cache is not None, "No kv_cache"
        k = self.key(xa[:, xa_input_pos])
        v = self.value(xa[:, xa_input_pos])

        # Reshape for correct format
        batch_size, source_seq_len, _ = k.shape
        k = k.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        v = v.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=xa_input_pos)

        return k, v

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
        xa_input_pos: Tensor,
    ):

        q = self.query(x)
        batch_size, target_seq_len, _ = q.shape
        q = q.view(
            batch_size, target_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        k, v = self.get_kv(xa, xa_input_pos)
        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            is_causal=False,
        )
        wv = wv.transpose(1, 2).reshape(
            batch_size,
            target_seq_len,
            self.n_head * self.d_head,
        )

        return self.out(wv)

class CausalSelfAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly devide n_state"

        self.n_head = n_head
        self.d_head = n_state // n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.kv_cache: KVCache | None = None

    def get_kv(self, x: torch.Tensor, input_pos: torch.Tensor):
        # Self attn
        k = self.key(x)
        v = self.value(x)

        # Reshape
        batch_size, source_seq_len, _ = k.shape
        k = k.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        v = v.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos)

        return k, v

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ):
        q = self.query(x)

        batch_size, target_seq_len, _ = q.shape
        q = q.view(
            batch_size, target_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        k, v = self.get_kv(x, input_pos=input_pos)
        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=mask,
        )

        # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
        wv = wv.transpose(1, 2).reshape(
            batch_size, target_seq_len, self.n_head * self.d_head
        )

        return self.out(wv)

class EncoderAttentionBlock(nn.Module):
    def __init__(
        self, n_state: int, n_head: int, cross_attention: bool = False
    ):
        super().__init__()
        self.attn = EncoderAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(
        self,
        xa: Tensor,
    ):
        xa = xa + self.attn(
            self.attn_ln(xa),
        )
        xa = xa + self.mlp(self.mlp_ln(xa))

        return xa

class DecoderAttentionBlock(nn.Module):
    def __init__(
        self, n_state: int, n_head: int, cross_attention: bool = False
    ):
        super().__init__()
        self.attn = CausalSelfAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)
        self.cross_attn = (
            CrossAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
        mask: Optional[Tensor] = None,
        x_input_pos: Optional[Tensor] = None,
        xa_input_pos: Optional[Tensor] = None,
    ):
        x = x + self.attn(
            self.attn_ln(x),
            mask=mask,
            input_pos=x_input_pos,
        )
        x = x + self.cross_attn(self.cross_attn_ln(x), xa, xa_input_pos)
        x = x + self.mlp(self.mlp_ln(x))

        return x

class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(
            n_state, n_state, kernel_size=3, stride=2, padding=1
        )
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[EncoderAttentionBlock] = nn.ModuleList(
            [EncoderAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, xa: Tensor):
        xa = F.gelu(self.conv1(xa))
        xa = F.gelu(self.conv2(xa))
        xa = xa.permute(0, 2, 1)

        assert (
            xa.shape[1:] == self.positional_embedding.shape
        ), f"incorrect audio shape: {xa.shape[1:]} != {self.positional_embedding.shape}"
        xa = (xa + self.positional_embedding).to(xa.dtype)

        for block in self.blocks:
            xa = block(xa)

        xa = self.ln_post(xa)
        return xa

class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[DecoderAttentionBlock] = nn.ModuleList(
            [
                DecoderAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = nn.LayerNorm(n_state)
        self.register_buffer("causal_mask", None, persistent=False)

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
        x_input_pos: Tensor,
        xa_input_pos: Tensor,
    ):
        mask = self.causal_mask[None, None, x_input_pos]
        x = self.token_embedding(x) + self.positional_embedding[x_input_pos]

        for block in self.blocks:
            x = block(
                x=x,
                xa=xa,
                mask=mask,
                x_input_pos=x_input_pos,
                xa_input_pos=xa_input_pos,
            )

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

    def setup_cache(
        self,
        batch_size,
        max_seq_len=4096,
        max_audio_len=1500,
    ):
        self.causal_mask = torch.tril(
            torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
        )
        # Init cache
        for b in self.blocks:
            b.attn.kv_cache = KVCache(
                max_batch_size=batch_size,
                max_seq_length=max_seq_len,
                n_heads=8,
                head_dim=64,
            ).cuda()
            b.cross_attn.kv_cache = KVCache(
                max_batch_size=batch_size,
                max_seq_length=max_audio_len,
                n_heads=8,
                head_dim=64,
            ).cuda()

class AmtEncoderDecoder(nn.Module):
    def __init__(self, dims: ModelConfig):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        _buff = self.encoder(mel)
        return self.decoder(tokens, _buff)

    @property
    def device(self):
        return next(self.parameters()).device
Chillee commented 7 months ago

This is a good question! I think there's two components of this question:

  1. The default FlashAttention kernel is not very performant for decoding. See https://pytorch.org/blog/flash-decoding/ for more detail.
  2. The attention kernel we generate does not early exit depending on the mask.