chijames / KERPLE

Apache License 2.0

Request for a simple implementation that can be used with Hugging Face #2

Closed switiz closed 1 year ago

switiz commented 1 year ago

First, thanks for sharing your great research.

I have reviewed the paper and the code, and it appears to add a KERPLE bias to the attention scores.

However, since the code is written for the GPT-NeoX framework, the model-parallel code makes it difficult to follow the actual behavior. Can you share a simpler version of the code?

For example, what would be the simplest way to add the KERPLE-log class below to the attached NeoX attention function?

KERPLE-log class

class ParallelKerpleLog(torch.nn.Module):
    """Kernelized T5 Relative Position Bias parallelized in the heads dimension"""

    def __init__(
        self,
        neox_args,
    ):
        super().__init__()
        self.heads = neox_args.num_attention_heads
        self.model_parallel_size = get_model_parallel_world_size()
        self.model_parallel_rank = get_model_parallel_rank()
        self.num_heads_per_partition = self.heads // self.model_parallel_size
        self.pos_emb = neox_args.pos_emb
        self.eps = 1e-2

        # megatron splits across heads, so we need to make sure each head receives the correct matrix
        assert self.model_parallel_size <= self.heads and self.model_parallel_rank <= self.model_parallel_size

        # Allocate weights and initialize.
        # The kernel has the form -p*log(1+a*|m-n|)
        def get_parameter(scale, init_method):
            if init_method == 'ones':
                return Parameter(torch.ones(
                               self.num_heads_per_partition,
                               device=torch.cuda.current_device(),
                               dtype=neox_args.params_dtype,
                               )[:,None,None]*scale )
            elif init_method == 'uniform':
                return Parameter(torch.rand(
                               self.num_heads_per_partition,
                               device=torch.cuda.current_device(),
                               dtype=neox_args.params_dtype,
                               )[:,None,None]*scale )

        self.bias_p = get_parameter(2, 'uniform')
        self.bias_a = get_parameter(1, 'uniform')

        self.cached_matrix = None
        self.cached_seq_len = None

    def stats(self):
        def get_stats(name, obj):
            return {name+'_mean': obj.mean().detach().cpu(),
                    name+'_std': obj.std().detach().cpu(),
                    name+'_max': obj.max().detach().cpu(),
                    name+'_min': obj.min().detach().cpu()}
        dd = {}
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        dd.update(get_stats('bias_a', self.bias_a))
        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        dd.update(get_stats('bias_p', self.bias_p))
        return dd

    def forward(self, x):
        # [b, np, sq, sk]
        seq_len_q = x.shape[-2]
        seq_len_k = x.shape[-1]
        if self.cached_seq_len != seq_len_k:
            diff = torch.tril(
                torch.arange(seq_len_k, device=x.device).view(seq_len_k, 1).repeat(1, seq_len_k)
                + torch.arange(0, -seq_len_k, -1, device=x.device)
            )
            diff = diff.to(x.dtype)
            self.cached_seq_len = seq_len_k
            self.cached_matrix = diff
        else:
            diff = self.cached_matrix

        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        bias = -self.bias_p*torch.log(1+self.bias_a*diff) # log kernel

        if seq_len_q != seq_len_k:
            # In the train case x has dimensionality [b, np, sq, sk] with sq == sk
            # The number of query tokens is equal to the number of key tokens
            # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence)
            # In this case we use the appropriate token index of the cache matrix.
            # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"

            if type(bias) != float:
                # seq_len_k - 1 points to the last token index in the current inference batch.
                bias = bias[:, seq_len_k - 1, :].view(bias.shape[0], 1, bias.shape[2])

        return x + bias

NeoX attention

        if exists(self.rpe):
            if self.pos_emb.startswith("kerple"):
                attention_scores = self.rpe(attention_scores)
            else:
                rpe = self.rpe(query_layer.size(0), key_layer.size(0))
                attention_scores += rpe  # [1, np, sq, sk]

Hugging Face NeoX attention

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
        # compute causal mask from causal mask buffer
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
        #-> maybe here is add point

        mask_value = torch.finfo(attn_scores.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_scores = attn_scores + attention_mask

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights
chijames commented 1 year ago

Hi,

Thanks for the question. Yes, KERPLE is just a bias imposed on the self-attention matrix before the softmax.

Regarding the complex model-parallel code: you can simply 1) remove all the model_parallel_* variables and 2) set num_heads_per_partition = num_attention_heads, because the number of partitions is always 1 when model parallelism is disabled.

"#-> maybe here is add point" is the correct place to add the kerple bias. Copying bias = -self.bias_p*torch.log(1+self.bias_a*diff) # log kernel and all the related lines/functions to there should do the trick.

This might need some trial and error. Please feel free to come back if the above suggestions don't work.
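
For reference, here is a minimal single-GPU sketch of ParallelKerpleLog following the two steps above (model-parallel variables removed, num_heads_per_partition = num_attention_heads). The class name KerpleLogBias and its constructor arguments are placeholders for illustration, not part of the repo:

import torch
from torch import nn

class KerpleLogBias(nn.Module):
    """Non-parallel sketch of ParallelKerpleLog. The kernel has the form -p*log(1+a*|m-n|)."""

    def __init__(self, num_attention_heads, eps=1e-2):
        super().__init__()
        self.heads = num_attention_heads  # equals num_heads_per_partition when model parallelism is disabled
        self.eps = eps
        # one (p, a) pair per head, shaped [np, 1, 1] so it broadcasts over [np, sq, sk]
        self.bias_p = nn.Parameter(torch.rand(self.heads)[:, None, None] * 2)
        self.bias_a = nn.Parameter(torch.rand(self.heads)[:, None, None] * 1)
        self.cached_matrix = None
        self.cached_seq_len = None

    def forward(self, x):
        # x: attention scores [b, np, sq, sk]
        seq_len_q, seq_len_k = x.shape[-2], x.shape[-1]
        if self.cached_seq_len != seq_len_k:
            # lower-triangular matrix of distances m-n (i.e. |m-n| for the causal part)
            diff = torch.tril(
                torch.arange(seq_len_k, device=x.device).view(seq_len_k, 1).repeat(1, seq_len_k)
                + torch.arange(0, -seq_len_k, -1, device=x.device)
            ).to(x.dtype)
            self.cached_seq_len = seq_len_k
            self.cached_matrix = diff
        else:
            diff = self.cached_matrix

        # keep both parameters positive (see the follow-up discussion below)
        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        bias = -self.bias_p * torch.log(1 + self.bias_a * diff)  # log kernel

        if seq_len_q != seq_len_k:
            # incremental decoding with a KV cache: only the last query token is present
            assert seq_len_q == 1
            bias = bias[:, seq_len_k - 1, :].view(bias.shape[0], 1, bias.shape[2])
        return x + bias

In the Hugging Face _attn above, you would instantiate this module once in __init__ and call attn_scores = self.kerple_bias(attn_scores) at the "#-> maybe here is add point" line, before the causal mask is applied.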

switiz commented 1 year ago

Dear chijames

It seems to work fine when applying the simple version of KERPLE-log as we discussed.

I trained Hugging Face GPT-2 on WikiText-103 with a max position of 1024. With the original learned positional embeddings, the test perplexity at length 1024 is 106.

With KERPLE-log, the test perplexity is 75.84 at length 1024, 73.64 at length 2048, and 72.59 at length 4096. The performance holds up even when the inference length exceeds the training length, and it is better than the original learned positional embeddings.

original

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

modified

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            if self.cached_seq_len != key_length:
                diff = torch.tril(
                    torch.arange(key_length, device=query.device).view(key_length, 1).repeat(1, key_length)
                    + torch.arange(0, -key_length, -1, device=query.device)
                )
                diff = diff.to(query.dtype)
                self.cached_seq_len = key_length
                self.cached_matrix = diff
            else:
                diff = self.cached_matrix
            kerple_bias = -self.kerple_bias_p*torch.log(1+self.kerple_bias_a*diff) # log kernel

            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights += kerple_bias
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

Thank you.

chijames commented 1 year ago

Sweet! Thanks for sharing the numbers with us.

There is a potential bug in the "modify" implementation. We need additional constraints on bias_a and bias_p:

self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
self.bias_a.data = self.bias_a.data.clamp(min=self.eps)

to make sure they are always positive. It would be better if you safeguard them from going negative.

It is fortunate that your experiments worked, because if these two parameters become negative during training, KERPLE no longer maintains the CPD (conditionally positive definite) property. For more information, please refer to Corollary 1 in Sec. 3.2 and "Practical Choice" in Sec. 4 of our paper:

[image: the KERPLE-log kernel from the paper] Here bias_a = r_2 and bias_p = r_1; c can be safely ignored in the implementation.
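
Concretely, in your modified block the clamp would go right before the bias is computed. This is just a sketch reusing the kerple_bias_p / kerple_bias_a names from your snippet; 1e-2 is the same eps value used in ParallelKerpleLog:

            # keep both kernel parameters positive so the CPD property holds
            self.kerple_bias_p.data = self.kerple_bias_p.data.clamp(min=1e-2)
            self.kerple_bias_a.data = self.kerple_bias_a.data.clamp(min=1e-2)
            kerple_bias = -self.kerple_bias_p*torch.log(1+self.kerple_bias_a*diff) # log kernel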

switiz commented 1 year ago

Thank you for your kind recommendation.

I trained the model again with the clamp applied, and the result was similar to the previous one. Maybe I was just lucky. :) In the final implementation, I will apply it according to your guidance. Thank you.

switiz commented 1 year ago

One thing I'm curious about: your method seems to be based on the causal mask, but can it be applied to a bidirectional transformer, for example BERT, or an encoder-decoder model like T5?

KERPLE has learnable parameters, so it might be possible, but have you ever tried it?

The T5 relative position bias is also effective in the bidirectional setting, but it is slow, so I wonder whether it could be replaced. Thank you.

chijames commented 1 year ago

T5 has three attention mechanisms: encoder self-attention, cross-attention, and decoder self-attention. The cross-attention mechanism doesn't seem to have positional embeddings in its original form (?), so I think there is no need to add KERPLE to it. As for encoder self-attention (as in BERT), I think KERPLE is definitely applicable once we discard the causal mask. I haven't tried it though.
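
If you want to try it, an untested, rough sketch would be to drop the tril, build a symmetric distance matrix, and reuse the same log kernel without the causal torch.where mask (bias_p, bias_a, and attn_scores here refer to the same quantities as in the snippets above):

        # bidirectional variant: use |m - n| over the full matrix instead of torch.tril(...)
        seq_len = attn_scores.shape[-1]
        positions = torch.arange(seq_len, device=attn_scores.device)
        diff = (positions.view(-1, 1) - positions.view(1, -1)).abs().to(attn_scores.dtype)  # symmetric |m - n|
        bias = -bias_p * torch.log(1 + bias_a * diff)  # same log kernel, now bidirectional
        attn_scores = attn_scores + bias  # no causal masking for encoder-style attention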