Hi,
Thanks for the question. Yes, kerple is just a set of biases added to the self-attention matrix before softmax.
Regarding the complex model parallel code: you can simply 1) remove all the model-parallel related variables and 2) set num_heads_per_partition = num_attention_heads, because the number of partitions is always 1 when model parallelism is disabled.
"#-> maybe here is add point" is the correct place to add the kerple bias. Copying
bias = -self.bias_p*torch.log(1+self.bias_a*diff) # log kernel
and all the related lines/functions there should do the trick.
This might need some trial and error. Please feel free to come back if the above suggestions don't work.
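For reference, here is a minimal single-GPU sketch of what such a kerple-log bias module could look like once the model parallel code is stripped. The class name, eps, and init values here are illustrative, not the repo's exact settings:

import torch
import torch.nn as nn

class KerpleLogBias(nn.Module):
    # Minimal non-model-parallel sketch: one (bias_p, bias_a) pair per attention head.
    def __init__(self, num_attention_heads, eps=1e-2):
        super().__init__()
        self.eps = eps
        # shape (heads, 1, 1) so the bias broadcasts over the (query, key) score matrix
        self.bias_p = nn.Parameter(torch.rand(num_attention_heads)[:, None, None] * 2)
        self.bias_a = nn.Parameter(torch.rand(num_attention_heads)[:, None, None])
        self.cached_seq_len = None
        self.cached_matrix = None

    def forward(self, attn_scores, seq_len):
        # keep both parameters positive (see the clamp discussion further down)
        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        if self.cached_seq_len != seq_len:
            # lower-triangular |i - j| distance matrix, cached per sequence length
            diff = torch.tril(
                torch.arange(seq_len, device=attn_scores.device).view(seq_len, 1).repeat(1, seq_len)
                + torch.arange(0, -seq_len, -1, device=attn_scores.device)
            ).to(attn_scores.dtype)
            self.cached_seq_len = seq_len
            self.cached_matrix = diff
        bias = -self.bias_p * torch.log(1 + self.bias_a * self.cached_matrix)  # log kernel
        # attn_scores: (batch, heads, seq, seq); add before the causal mask / softmax
        return attn_scores + bias

In huggingface gpt2 terms, attn_scores corresponds to attn_weights right before the causal masking step.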
Dear chijames
It seems to work fine when applying the simple version of kerple log as we discussed.
I trained huggingface gpt2 on wikitext-103 with a max position of 1024. With the original learned position embedding, the test ppl at length 1024 is 106.
With kerple log, the test ppl is 75.84 at length 1024, 73.64 at length 2048, and 72.59 at length 4096. Performance is maintained even when the inference length is increased beyond the training length, and it is better than the original learned position embedding.
For reference, the original huggingface gpt2 attention:

if not self.is_cross_attention:
    # if only "normal" attention layer implements causal mask
    query_length, key_length = query.size(-2), key.size(-2)
    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
    mask_value = torch.finfo(attn_weights.dtype).min
    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
    attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
My modified version with the kerple log bias (the "modify" implementation):

if not self.is_cross_attention:
    # if only "normal" attention layer implements causal mask
    query_length, key_length = query.size(-2), key.size(-2)
    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
    if self.cached_seq_len != key_length:
        # build and cache the lower-triangular |i - j| distance matrix
        diff = torch.tril(
            torch.arange(key_length, device=query.device).view(key_length, 1).repeat(1, key_length)
            + torch.arange(0, -key_length, -1, device=query.device)
        )
        diff = diff.to(query.dtype)
        self.cached_seq_len = key_length
        self.cached_matrix = diff
    else:
        diff = self.cached_matrix
    kerple_bias = -self.kerple_bias_p * torch.log(1 + self.kerple_bias_a * diff)  # log kernel
    mask_value = torch.finfo(attn_weights.dtype).min
    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
    attn_weights += kerple_bias
    attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
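(For completeness, the snippet above assumes the kerple parameters and the cache are created in GPT2Attention.__init__. A sketch along these lines, with illustrative shapes and init values:)

# somewhere in GPT2Attention.__init__ (sketch; init values are illustrative)
# self.num_heads = config.num_attention_heads in huggingface's GPT2Attention
self.kerple_eps = 1e-2
self.kerple_bias_p = nn.Parameter(torch.rand(self.num_heads)[:, None, None] * 2)
self.kerple_bias_a = nn.Parameter(torch.rand(self.num_heads)[:, None, None])
self.cached_seq_len = None
self.cached_matrix = None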
Thank you.
Sweet! Thanks for sharing the numbers with us.
There is a potential bug in the "modify" implementation. We need additional constraints on bias_a and bias_p:
self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
to make sure they are always positive. It would be better to safeguard them against going negative.
It's very lucky that your experiments worked, because if these two parameters become negative during the training process, kerple no longer maintains the CPD property. For more information, please refer to Corollary 1 in sec. 3.2 and the Practical Choice paragraph in sec. 4 of our paper:
bias_a = r_2 and bias_p = r_1. c can be safely ignored in the implementation.
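In the gpt2 modification above, applying the clamp could look something like the following, right before the bias is computed (kerple_bias_p / kerple_bias_a / kerple_eps follow the illustrative naming of the earlier snippets):

# e.g. in GPT2Attention._attn, right before kerple_bias is computed (sketch)
self.kerple_bias_p.data = self.kerple_bias_p.data.clamp(min=self.kerple_eps)
self.kerple_bias_a.data = self.kerple_bias_a.data.clamp(min=self.kerple_eps)
kerple_bias = -self.kerple_bias_p * torch.log(1 + self.kerple_bias_a * diff)  # log kernel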
Thank you for your kind recommendation.
I trained the model again with the clamp applied, and the result was similar to the previous one. Maybe I was just lucky. :) In the final implementation, I will apply it according to your guide. Thank you.
One thing I'm curious about: your method seems to be based on the causal mask, but can it also be applied to a bidirectional transformer, for example bert, or an encoder/decoder model like t5?
kerple has learnable parameters, so it might be possible, but have you ever tried it?
The t5 relative position bias is also effective in the bidirectional case, but it is slow, so I wonder if it could be replaced. Thank you.
T5 has three attention mechanisms: encoder self-attention, cross-attention, and decoder self-attention. The cross-attention mechanism doesn't seem to have positional embeddings in its original form (?), so I think there is no need to add kerple to it. As for the encoder self-attention (as in BERT), I think kerple is definitely applicable after we discard the causal mask. I haven't tried it though.
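An untested sketch of what the encoder-style (no causal mask) version could look like, reusing the illustrative kerple_bias_p / kerple_bias_a parameters from the gpt2 modification above. The only real change is that the distance matrix becomes the full symmetric |i - j| and the causal torch.where is dropped:

# bidirectional variant (sketch): full symmetric |i - j| distance matrix, no causal mask
positions = torch.arange(key_length, device=query.device)
diff = (positions.view(-1, 1) - positions.view(1, -1)).abs().to(query.dtype)
kerple_bias = -self.kerple_bias_p * torch.log(1 + self.kerple_bias_a * diff)  # log kernel
attn_weights = attn_weights + kerple_bias  # then softmax as usual, without the causal torch.where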
First, thanks for sharing your great research.
I have reviewed the paper and the code, and it appears to be a form of adding a kerple bias to the attention score.
However, since the code is in the neox framework, it is difficult to understand the actual behavior due to the model parallel code. Can you share a simple version of the code?
For example, what would be the simplest way to add the kerple log class below to the attached neox attention function?
kerplelog class
neox attention
huggingface neox attention