Intermediate update: Adapting a pre-trained model to dilated attention appears to be slow. So far I have looked at fine-tuning runs of the LLongMA-3b model with up to ~1k steps. The loss quickly falls to the range it would reach with regular attention at a context equal to the block size, but beyond that it only improves slowly.
Selected loss curves:
- llongma_baseline: LLongMA training with 8k context
- orig_1024: LLongMA training with 1k context
- orig_2048: LLongMA training with 2k context
- dilated_8k_lr5_fix1: training with the dilated-attention patch (blocks: [1k, 2k, 4k, 8k], dilations: [1, 2, 4, 8])

I'll prepare some more experiments with smaller (Pythia) models to (again) verify my dilated-attention implementation and to better estimate how long an adaptation would take.
Due to lack of time and better-working options for pre-trained models, I decided to discontinue the dilated-attention experiments.
Have you implemented dilated attention with the LLongMA model? If you have, could you please share inference code?
Below is the multi-head dilated-attention module I tried; I also used FlashAttention for memory efficiency.
from typing import Optional, Sequence, Tuple, Union

import torch
from einops import rearrange
from torch import Tensor, nn

# NOTE: `DilatedAttention` is assumed to be importable from the dilated-attention
# implementation being used (e.g. a LongNet-style module); it is not defined here.


class MultiheadDilatedAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
dilation_rates: Sequence[int],
segment_lengths: Sequence[int],
dropout: float = 0.0,
bias: bool = True,
layer_norm: bool = True,
layer_norm_eps: float = 1e-5,
gamma_init: float = 1.0,
device: Optional[Union[torch.device, str]] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.num_heads = num_heads
self.layer_norm = layer_norm
self.gamma_init = gamma_init
if not embed_dim % self.num_heads == 0:
raise ValueError(
f"embed_dim ({embed_dim}) must be divisible by num_heads"
f" ({num_heads})"
)
num_dilations = len(dilation_rates)
num_segments = len(segment_lengths)
if num_dilations != num_segments:
raise ValueError(
f"len(dilation_rates) ({num_dilations}) must be equal to "
f"len(segment_lengths) ({num_segments})"
)
head_dim = embed_dim // num_heads
if not head_dim % 8 == 0:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be"
" divisible by 8"
)
if not head_dim <= 128:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
)
self.q_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.k_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.v_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.attention = DilatedAttention(
d_model=embed_dim,
num_heads=num_heads,
segment_size=segment_lengths,
dilation_rate=dilation_rates,
dropout=dropout,
# op=op,
)
self.norm: Optional[nn.LayerNorm] = None
if layer_norm:
self.norm = nn.LayerNorm(
embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
)
self.out_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self._reset_parameters()
def _reset_parameters(self):
nn.init.xavier_normal_(self.q_proj.weight)
if self.q_proj.bias is not None:
nn.init.constant_(self.q_proj.bias, 0)
nn.init.xavier_normal_(self.k_proj.weight)
if self.k_proj.bias is not None:
nn.init.constant_(self.k_proj.bias, 0)
# NOTE: We follow the initialization strategy from MAGNETO. See:
# https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
# Gain (self.gamma_init) should be provided as a keyword argument when
# initializing the larger Transformer model, since it requires knowledge
# of the number of encoder/decoder layers in the model.
nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
if self.v_proj.bias is not None:
nn.init.constant_(self.v_proj.bias, 0)
nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0)
def forward(
self, query: Tensor, key: Tensor, value: Tensor, is_causal: bool = False
) -> Tuple[Tensor, None]:
# Notation:
# b - batch size
# n - sequence length
# h - number of heads
# d - embedding dimension
#
# Input shape: (b, n, d)
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
# Unfold 'd' dimension into 'h' separate attention heads.
q = rearrange(q, "b n (h d) -> b n h d", h=self.num_heads)
k = rearrange(k, "b n (h d) -> b n h d", h=self.num_heads)
v = rearrange(v, "b n (h d) -> b n h d", h=self.num_heads)
# Apply attention, then fold 'h' attention heads back into 'd'.
x = self.attention(q, k, v, causal=is_causal)
x = rearrange(x, "b n h d -> b n (h d)")
if self.layer_norm:
assert self.norm is not None
x = self.norm(x)
# Linear projection on attention outputs.
x = self.out_proj(x)
return x, None
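For reference, here is a minimal usage sketch of the module above. The segment lengths and dilation rates mirror the fine-tuning run mentioned earlier (1k/2k/4k/8k blocks with dilations 1/2/4/8); the embedding size, head count, device and dtype are illustrative, and `DilatedAttention` is whatever implementation the module is wired to:

```python
import torch

# Illustrative sizes: head_dim = 2048 / 16 = 128 satisfies the checks in __init__.
embed_dim, num_heads = 2048, 16

attn = MultiheadDilatedAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    segment_lengths=[1024, 2048, 4096, 8192],
    dilation_rates=[1, 2, 4, 8],
    device="cuda",
    dtype=torch.float16,
)

# Self-attention over an 8k sequence; the sequence length should be a
# multiple of the largest segment length.
x = torch.randn(1, 8192, embed_dim, device="cuda", dtype=torch.float16)
out, _ = attn(x, x, x, is_causal=True)
print(out.shape)  # torch.Size([1, 8192, 2048])
```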
RoPE scaling has been shown to be effective by @jquesnelle, @conceptofmind, Kaiokendev and NousResearch, who recently released (tweet) a set of long-context fine-tuned models of open-llama called LLongMA (3b, 7b, 13b). They publish training and evaluation code in jquesnelle/scaled-rope.
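The RoPE scaling used for LLongMA is (linear) position interpolation: positions in the extended context are compressed by the ratio of original to extended context length before the rotary angles are computed, so the fine-tuned model stays inside the position range seen during pre-training. A minimal sketch of the idea (the context lengths below are illustrative, not the exact LLongMA configuration):

```python
import torch

def rope_angles(positions: torch.Tensor, head_dim: int,
                base: float = 10000.0, scale: float = 1.0) -> torch.Tensor:
    """Rotary-embedding angles with optional linear position interpolation.

    scale = original_context / extended_context (e.g. 2048 / 8192 = 0.25)
    compresses extended positions back into the pre-training range.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    return torch.outer(positions.float() * scale, inv_freq)  # (seq_len, head_dim // 2)

orig = rope_angles(torch.arange(2048), head_dim=128)                        # base model
interp = rope_angles(torch.arange(8192), head_dim=128, scale=2048 / 8192)   # extended

# The interpolated 8k positions cover almost exactly the same angular range
# the base model saw at 2k context:
print(orig.max().item(), interp.max().item())  # 2047.0, 2047.75
```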
The following plot (shared in #sft of the CarperAI discord) nicely shows the effectiveness of LLongMA fine-tuning:
One remaining problem of position-interpolated long-context open-source models (especially the larger ones) is slow inference (a low number of tokens/s). A possible solution would be to use sparse attention, such as the dilated attention recently described by Microsoft Research in the LongNet paper.
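To sketch the idea (a simplified, single `(segment, dilation)` pair, not the LongNet or torchscale implementation): the sequence is split into segments of length w, and within each segment only every r-th position takes part in attention, reducing the per-segment cost from O(w²) to O((w/r)²). LongNet mixes several such pairs (and offsets them across heads) so that every position is covered.

```python
import torch
import torch.nn.functional as F

def dilated_attention_sketch(q, k, v, segment_length: int, dilation: int):
    """Dilated attention for a single (segment_length, dilation) pair.

    q, k, v: (batch, seq_len, num_heads, head_dim); seq_len must be a
    multiple of segment_length. Positions skipped by the dilation get zero
    output here; the full method combines several pairs to cover them.
    """
    b, n, h, d = q.shape
    out = torch.zeros_like(q)
    for start in range(0, n, segment_length):
        idx = torch.arange(start, start + segment_length, dilation, device=q.device)
        qs, ks, vs = q[:, idx], k[:, idx], v[:, idx]  # (b, w/r, h, d)
        # scaled_dot_product_attention expects (b, h, seq, d)
        att = F.scaled_dot_product_attention(
            qs.transpose(1, 2), ks.transpose(1, 2), vs.transpose(1, 2),
            is_causal=True,
        ).transpose(1, 2)
        out[:, idx] = att
    return out

q = k = v = torch.randn(1, 8192, 16, 64)
y = dilated_attention_sketch(q, k, v, segment_length=2048, dilation=4)
print(y.shape)  # torch.Size([1, 8192, 16, 64])
```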
We will conduct a basic experiment to see whether fine-tuning a LLongMA model (with a limited training budget of ~1B tokens) for dilated attention might be feasible.
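To put the budget in perspective: at 8k tokens per sequence, ~1B tokens correspond to roughly 122k sequences; with a hypothetical global batch of 128 sequences, that would be on the order of 1,000 optimizer steps.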