ofirpress / attention_with_linear_biases

Code for the ALiBi method for transformer language models (ICLR 2022)
MIT License

How can I apply ALiBi Position Encoding into huggingface model? #11

Closed. hjsg1010 closed this issue 1 year ago.

hjsg1010 commented 1 year ago

How can I apply ALiBi Position Encoding into huggingface model?

I implemented ALiBi position encoding as in the code below, referencing the eleutherai/gpt-neox implementation.

import math

import torch


class AliBi(torch.nn.Module):
    def __init__(self, num_heads, mp_size=1, mp_rank=0):
        super().__init__()
        # megatron splits across heads, so we need to make sure each
        # head receives the correct matrix
        # mp_rank is used as a 0-based slice index below, so the
        # single-partition defaults are mp_size=1, mp_rank=0
        # (with mp_size=1, a rank of 1 would select an empty slope slice)
        assert mp_size <= num_heads and mp_rank <= mp_size
        self.mp_size = mp_size
        self.mp_rank = mp_rank
        self.num_heads = num_heads
        self.slice_size = num_heads // mp_size
        self.cached_matrix = None
        self.cached_seq_len = None
        slopes = torch.Tensor(self._get_slopes(num_heads))[
            mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size
        ]
        self.register_buffer("slopes", slopes)

    def _get_slopes(self, n):
        """
        Get slopes for Alibi positional embedding
        n : int = number of heads.
        For best performance, restrict n to a power of 2.
        """

        def get_slopes_power_of_2(n):
            # geometric sequence: 2^(-8/n), 2^(-16/n), ..., 2^(-8)
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + self._get_slopes(2 * closest_power_of_2)[0::2][
                    : n - closest_power_of_2
                ]
            )

    def forward(self, x):
        # [b, np, sq, sk]
        seq_len_q = x.shape[-2]
        seq_len_k = x.shape[-1]
        if self.cached_seq_len != seq_len_k:
            # lower-triangular distance matrix: a[i, j] = -(i - j) for j <= i, 0 elsewhere
            a = -torch.tril(
                torch.arange(seq_len_k).view(seq_len_k, 1).repeat(1, seq_len_k)
                + torch.arange(0, -seq_len_k, -1)
            )
            a = a.to(x.device).to(x.dtype)
            slopes = self.slopes.to(a.device).to(a.dtype)
            a = a * slopes.view(self.slopes.shape[0], 1, 1)
            self.cached_seq_len = seq_len_k
            self.cached_matrix = a
        else:
            a = self.cached_matrix

        if seq_len_q != seq_len_k:
            # During training x has shape [b, np, sq, sk] with sq == sk.
            # At inference time with a cache in layer_past, sq != sk: the query
            # contains only one token (the last one in the full sequence), so we
            # select the corresponding row of the cached bias matrix. Because the
            # cached matrix could be larger than needed, we index with
            # seq_len_k - 1 rather than simply taking the last row.
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
            a = a[:, seq_len_k - 1, :].view(
                a.shape[0], 1, a.shape[2]
            )  # seq_len_k - 1 is the index of the last token in the current step

        return x + a
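
As a quick sanity check, the module can be fed random attention scores and the added bias inspected (a minimal sketch; the batch size, head count, and sequence length below are arbitrary):

import torch

batch, num_heads, seq_len = 2, 8, 16                      # arbitrary sizes
scores = torch.randn(batch, num_heads, seq_len, seq_len)  # [b, np, sq, sk]

alibi = AliBi(num_heads, mp_size=1, mp_rank=0)  # single partition, 0-based rank
biased = alibi(scores)

assert biased.shape == scores.shape
# For the last query position, the added bias is 0 for the nearest key and
# becomes more negative linearly with distance, scaled by each head's slope.
print((biased - scores)[0, 0, -1])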

I've tried the following (in the case of a GPT2-based model):

import torch
from torch import nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention


class GPT2AttentionAlibi(GPT2Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # self.n_head exists in older transformers releases; newer GPT2Attention
        # exposes the head count as self.num_heads
        self.alibi_positional_bias = AliBi(self.n_head)  # single-partition defaults

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        # raw attention scores; assumes k is already transposed to [..., d, sk]
        # (note: unlike GPT2Attention._attn in recent transformers, this skips the
        # causal mask and 1/sqrt(d) scaling and returns only the weighted values)
        w = torch.matmul(q, k)

        # Apply AliBi
        w = self.alibi_positional_bias(w)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if head_mask is not None
        if head_mask is not None:
            w = w * head_mask

        outputs = torch.matmul(w, v)
        return outputs

def replace_attention_layer(model):
    for block in model.transformer.h:
        attn = block.attn
        config = attn.config
        attn_new = GPT2AttentionAlibi(config)
        attn_new.load_state_dict(attn.state_dict())
        block.attn = attn_new

    return model
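
For context, this is how the replacement function above is meant to be invoked (a sketch only; "gpt2" is just a placeholder checkpoint, and it assumes a transformers version in which GPT2Attention still exposes the n_head and config attributes used above):

from transformers import GPT2LMHeadModel

# placeholder checkpoint for illustration
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = replace_attention_layer(model)

# every block's attention module should now carry the ALiBi bias
print(type(model.transformer.h[0].attn).__name__)  # GPT2AttentionAlibi
print(model.transformer.h[0].attn.alibi_positional_bias.slopes.shape)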

However, it doesn't work. Could you show me the right way to apply ALiBi to Hugging Face models (such as Llama, Polyglot, etc.)?

Regards.

ofirpress commented 1 year ago

Hi! ALiBi can only be used by models that were trained with it. Llama wasn't trained with ALiBi, so you can't do that.
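
Models that were trained with ALiBi (for example BLOOM and MPT) already apply it inside their attention layers, so they can be used directly without any attention surgery. A sketch, with bigscience/bloom-560m as an arbitrary example checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

# bigscience/bloom-560m is just one checkpoint known to use ALiBi
name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

inputs = tokenizer("ALiBi biases attention scores", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))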

hjsg1010 commented 1 year ago

> Hi! ALiBi can only be used by models that were trained with it. Llama wasn't trained with ALiBi, so you can't do that.

Oh, then how about GPT-NeoX-based models, such as Polyglot?

ofirpress commented 1 year ago

I believe those weren't trained with ALiBi...