ofirpress / attention_with_linear_biases

Code for the ALiBi method for transformer language models (ICLR 2022)
MIT License

ALiBi in Parallel Attention #8

Closed: conceptofmind closed this issue 2 years ago

conceptofmind commented 2 years ago

Hi @ofirpress,

I am working on implementing ALiBi in a parallel attention Transformer. I have removed the positional embeddings from the model, set up a relative ALiBi bias matrix, and calculated the per-head slopes. I then add the ALiBi attention bias to the causally masked attention scores. Unfortunately, I am unable to get the correct number of trainable parameters. Could you take a quick look and see whether anything is noticeably wrong in the implementation below?

Code

Function for slopes:

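from math import floor, log2

def get_alibi_slopes(heads):
    # Geometric sequence of per-head slopes that starts at 2^(-8/n) with ratio 2^(-8/n)
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(log2(n) - 3))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if log2(heads).is_integer():
        return get_slopes_power_of_2(heads)

    # Otherwise: slopes for the closest lower power of 2, plus every other
    # slope from the next power of 2 to fill the remaining heads
    closest_power_of_2 = 2 ** floor(log2(heads))
    return (
        get_slopes_power_of_2(closest_power_of_2)
        + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: heads - closest_power_of_2]
    )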
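A quick sanity check on the slopes, repeated here so it runs standalone: for a power-of-2 head count the sequence starts at 2^(-8/n), so 8 heads should give 1/2, 1/4, ..., 1/256. The slopes are fixed constants, so they cannot change the trainable parameter count either way.

```python
from math import floor, log2

def get_alibi_slopes(heads):
    # Same recipe as the function above
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(log2(n) - 3))
        return [start * start**i for i in range(n)]

    if log2(heads).is_integer():
        return get_slopes_power_of_2(heads)
    closest = 2 ** floor(log2(heads))
    # 12 heads -> 8 slopes from n=8 plus 4 interleaved slopes from n=16
    return get_slopes_power_of_2(closest) + get_slopes_power_of_2(2 * closest)[0::2][: heads - closest]

print(get_alibi_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```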

Calculate the ALiBi bias:

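import torch
from einops import rearrange

def calc_alibi_bias(seq_len, heads):
    # (heads, 1, 1) slopes times key positions (1, 1, seq_len) -> (heads, 1, seq_len).
    # Under a causal mask this equals -m * (i - j) up to a per-row constant m * i,
    # which softmax ignores. Note: the result is created on the CPU.
    slopes = torch.Tensor(get_alibi_slopes(heads))
    slopes = rearrange(slopes, 'h -> h 1 1')
    bias = rearrange(torch.arange(seq_len), 'j -> 1 1 j')
    return slopes * bias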
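The key-position form used in calc_alibi_bias (slope * j) and the distance form from the ALiBi paper (-slope * (i - j)) differ only by slope * i, which is constant along each query row. After causal masking, softmax therefore gives the same attention weights. A small check with random scores (sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
heads, n = 4, 6
slopes = torch.tensor([0.5 ** (k + 1) for k in range(heads)]).view(heads, 1, 1)
sim = torch.randn(heads, n, n)

i = torch.arange(n).view(n, 1)  # query positions
j = torch.arange(n).view(1, n)  # key positions
causal_mask = j > i  # True above the diagonal = future keys

# Bias as built by calc_alibi_bias: slope * key position, same for every query row
bias_keypos = slopes * j
# Bias as written in the ALiBi paper: -slope * (query position - key position)
bias_distance = -slopes * (i - j)

def attend(bias):
    # Add the bias, mask future keys, then normalize over keys
    scores = (sim + bias).masked_fill(causal_mask, float("-inf"))
    return scores.softmax(dim=-1)

print(torch.allclose(attend(bias_keypos), attend(bias_distance), atol=1e-6))  # True
```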

Build the Parallel Attention Block:

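import torch
from einops import rearrange, repeat
from torch import einsum, nn

# RMSNorm and SwiGLU are assumed to be defined elsewhere (not shown in this post)
class ParallelAttentionBlock(nn.Module):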
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = RMSNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask
        self.register_buffer("mask", None, persistent=False)

    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), 1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def forward(self, x):
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        q = rearrange(q, "b n (h d) -> b h n d", h = h)

        # scale
        q = q * self.scale

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # alibi bias: (h, 1, n) broadcast to (h, n, n); calc_alibi_bias builds it
        # on the CPU, so move it to the same device and dtype as sim
        alibi_bias = calc_alibi_bias(n, heads = h)
        attn_bias = repeat(alibi_bias, 'h 1 j -> h i j', i = n).to(sim)

        # add the bias before masking, so masked positions keep the fill value
        sim = sim + attn_bias

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")

        merge_heads = self.attn_out(out) + self.ff_out(ff)
        return merge_heads
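On the parameter-count question: ALiBi's slopes are fixed, not learned, so the bias itself adds no trainable parameters. The only change relative to a model with a learned absolute embedding is the removed nn.Embedding(max_seq_len, dim). A minimal sketch that counts the block's linear layers by hand, assuming hypothetical sizes (dim=512, max_seq_len=2048), omitting RMSNorm because its definition is not shown, and assuming SwiGLU is the usual parameter-free chunk-and-gate:

```python
import torch
from torch import nn

def count_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

dim, dim_head, heads, ff_mult = 512, 64, 8, 4
max_seq_len = 2048  # hypothetical context length
attn_inner_dim = dim_head * heads
ff_inner_dim = dim * ff_mult

# Stand-ins for the linear layers in ParallelAttentionBlock (single-head k and v)
fused_attn_ff_proj = nn.Linear(dim, attn_inner_dim + 2 * dim_head + 2 * ff_inner_dim, bias=False)
attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
ff_out = nn.Linear(ff_inner_dim, dim, bias=False)

expected = dim * (attn_inner_dim + 2 * dim_head + 2 * ff_inner_dim) + attn_inner_dim * dim + ff_inner_dim * dim
actual = sum(count_trainable(m) for m in (fused_attn_ff_proj, attn_out, ff_out))

# What removing learned absolute positions takes away from the whole model
pos_emb = nn.Embedding(max_seq_len, dim)

# ALiBi slopes stored as a buffer, not a Parameter, contribute nothing
class FixedSlopes(nn.Module):
    def __init__(self, heads):
        super().__init__()
        self.register_buffer("slopes", torch.ones(heads), persistent=False)

print(expected, actual, count_trainable(FixedSlopes(heads)))  # 3735552 3735552 0
```

If the model's total differs from the sum of such per-layer counts minus max_seq_len * dim, the mismatch is somewhere other than the ALiBi bias.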

Any help would be greatly appreciated!

Thank you,

Enrico

ofirpress commented 2 years ago

Sorry, I'm not sure what the problem is. ALiBi should be re-implemented by HuggingFace soon (since it's a component in the BigScience BLOOM model), so you can check out their code too.

conceptofmind commented 2 years ago

Hi @ofirpress,

Thank you for taking a look.

I ended up resolving the issue by using a custom Alibi class with the help of a peer.

Best,

Enrico

For anyone who is interested:

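# ALiBi

import torch
import torch.nn.functional as F
from einops import rearrange
from math import floor, log2
from torch import nn

def exists(val):
    return val is not None

class AlibiPositionalBias(nn.Module):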
    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

    def get_bias(self, i, j, device):
        i_arange = torch.arange(i, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** floor(log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

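    def forward(self, qk_sim):
        # qk_sim: (..., h, i, j) attention scores; returns qk_sim plus the ALiBi bias
        h, i, j, device = *qk_sim.shape[-3:], qk_sim.device

        # reuse the cached bias only if it covers both the query and key lengths
        if exists(self.bias) and self.bias.shape[-2] >= i and self.bias.shape[-1] >= j:
            return qk_sim + self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # heads beyond the number of slopes get a zero bias
        num_heads_unalibied = h - bias.shape[0]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
        self.register_buffer('bias', bias, persistent=False)

        # return the biased scores here too, matching the cached branch above
        return qk_sim + bias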
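The class's forward can be mirrored by a small stateless function to show the call order inside attention: scores, then ALiBi, then the causal mask, then softmax. add_alibi is a hypothetical name; it drops the caching and assumes slopes is a 1-D tensor with at most as many entries as there are heads, so that, as in the class, the remaining heads receive a zero bias through F.pad.

```python
import torch
import torch.nn.functional as F

def add_alibi(sim, slopes):
    """Hypothetical stateless version of AlibiPositionalBias.forward (no caching).

    sim: (batch, heads, i, j) attention scores; slopes: (alibi_heads,) with alibi_heads <= heads.
    """
    h, i, j = sim.shape[-3:]
    i_pos = torch.arange(i, device=sim.device).view(i, 1)
    j_pos = torch.arange(j, device=sim.device).view(1, j)
    distance = -(j_pos - i_pos).abs()                       # (i, j), 0 on the diagonal
    bias = slopes.to(sim).view(-1, 1, 1) * distance         # (alibi_heads, i, j)
    bias = F.pad(bias, (0, 0, 0, 0, 0, h - bias.shape[0]))  # zero bias for remaining heads
    return sim + bias

batch, heads, n = 2, 8, 5
slopes = torch.tensor([0.5 ** (k + 1) for k in range(4)])   # only 4 ALiBi heads here
sim = torch.zeros(batch, heads, n, n)

scores = add_alibi(sim, slopes)
causal_mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
attn = scores.masked_fill(causal_mask, float("-inf")).softmax(dim=-1)
```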