LAION-AI / Open-Assistant

OpenAssistant is a chat-based assistant that understands tasks, can interact with third-party systems, and retrieve information dynamically to do so.
https://open-assistant.io
Apache License 2.0

Experiment: LLongMA with dilated attention #3557

Closed andreaskoepf closed 1 year ago

andreaskoepf commented 1 year ago

RoPE scaling has been shown to be effective by @jquesnelle, @conceptofmind, Kaiokendev and NousResearch, who recently released (tweet) a set of long-context fine-tuned models of open-llama called LLongMA (3b, 7b, 13b). They published the training and evaluation code in jquesnelle/scaled-rope.
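
For context, the core of this "position interpolation" trick is to compress the positions of a longer context back into the positional range the model saw during pre-training, and then fine-tune briefly. A minimal sketch of that idea (simplified helper names of my own, not the scaled-rope code):

from typing import Tuple

import torch


def build_rope_cache(
    seq_len: int, head_dim: int, base: float = 10000.0, scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Standard RoPE frequencies; `scale` < 1.0 is the position-interpolation trick,
    # e.g. scale = 2048 / 8192 squeezes 8k positions into the pre-trained 2k range.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float() * scale
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim / 2)
    return angles.cos(), angles.sin()


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., seq_len, head_dim); rotate even/odd feature pairs by the cached angles.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)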

The following plot (shared in #sft of the CarperAI Discord) nicely shows the effectiveness of LLongMA fine-tuning: [image]

One remaining problem of position-interpolated long-context open-source models (especially the larger ones) is slow inference (a low number of tokens/s). A possible solution would be to use a sparse-attention mechanism such as the dilated attention recently described by Microsoft Research in the LongNet paper.
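
Roughly, LongNet's dilated attention splits the sequence into segments and lets each query attend only to every r-th token inside its segment; several (segment length, dilation rate) pairs are combined so that small rates preserve local detail while long segments with large rates provide global reach. A simplified single-head sketch of one such branch (my own naming, omitting LongNet's head-wise offsets and the softmax-denominator weighting used when combining branches; requires PyTorch 2.x):

import torch
import torch.nn.functional as F


def dilated_attention_branch(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    segment_len: int, dilation: int, causal: bool = True,
) -> torch.Tensor:
    # q, k, v: (batch, seq_len, head_dim); seq_len must be a multiple of segment_len.
    b, n, d = q.shape

    def sparsify(x: torch.Tensor) -> torch.Tensor:
        # Split into non-overlapping segments, then keep every `dilation`-th token.
        return x.view(b, n // segment_len, segment_len, d)[:, :, ::dilation]

    qs, ks, vs = sparsify(q), sparsify(k), sparsify(v)

    # Dense attention inside each sparsified segment; cost shrinks by dilation^2.
    out = F.scaled_dot_product_attention(qs, ks, vs, is_causal=causal)

    # Scatter results back to their original positions; skipped slots stay zero and
    # are covered by the other (segment_len, dilation) branches of the full method.
    full = torch.zeros(b, n // segment_len, segment_len, d, dtype=out.dtype, device=out.device)
    full[:, :, ::dilation] = out
    return full.reshape(b, n, d)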

We will conduct a basic experiment to see whether fine-tuning an LLongMA model for dilated attention (with a limited training budget of ~1B tokens) might be feasible.

andreaskoepf commented 1 year ago

Intermediate update: Adapting a pre-trained model to dilated attention seems to be slow. So far I have looked at fine-tuning runs of the LLongMA-3b model with up to ~1k steps. The loss quickly falls to the range it would reach with regular attention restricted to the segment (block) size as context, but it improves only slowly beyond that.

Selected loss curves: [image]

I'll prepare some more experiments with smaller (pythia) models to (again) verify my dilated-attention implementation and to better estimate how long an adaptation would take.

andreaskoepf commented 1 year ago

Due to lack of time, and because better-working options for pre-trained models exist, I have decided to discontinue the dilated-attention experiments.

younesselbrag commented 8 months ago

Have you implemented dilated attention with the LLongMA model? If so, could you please share the inference code?

This is the multi-head dilated attention module I tried; I also used Flash Attention for memory efficiency.


# `DilatedAttention` is assumed to be provided by an external LongNet-style
# implementation; it is not defined in this snippet.
from typing import Optional, Sequence, Tuple, Union

import torch
from einops import rearrange
from torch import Tensor, nn


class MultiheadDilatedAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dilation_rates: Sequence[int],
        segment_lengths: Sequence[int],
        dropout: float = 0.0,
        bias: bool = True,
        layer_norm: bool = True,
        layer_norm_eps: float = 1e-5,
        gamma_init: float = 1.0,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.gamma_init = gamma_init

        if not embed_dim % self.num_heads == 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads"
                f" ({num_heads})"
            )
        num_dilations = len(dilation_rates)
        num_segments = len(segment_lengths)
        if num_dilations != num_segments:
            raise ValueError(
                f"len(dilation_rates) ({num_dilations}) must be equal to "
                f"len(segment_lengths) ({num_segments})"
            )
        head_dim = embed_dim // num_heads
        if not head_dim % 8 == 0:
            raise ValueError(
                f"head_dim (embed_dim / num_heads = {head_dim}) must be"
                " divisible by 8"
            )
        if not head_dim <= 128:
            raise ValueError(
                f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
            )

        self.q_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.k_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.v_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.attention = DilatedAttention(
            d_model=embed_dim,
            num_heads=num_heads,
            segment_size=segment_lengths,
            dilation_rate=dilation_rates,
            dropout=dropout,
            # op=op,
        )
        self.norm: Optional[nn.LayerNorm] = None
        if layer_norm:
            self.norm = nn.LayerNorm(
                embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
            )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_normal_(self.q_proj.weight)
        if self.q_proj.bias is not None:
            nn.init.constant_(self.q_proj.bias, 0)
        nn.init.xavier_normal_(self.k_proj.weight)
        if self.k_proj.bias is not None:
            nn.init.constant_(self.k_proj.bias, 0)

        # NOTE: We follow the initialization strategy from MAGNETO.  See:
        # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
        # Gain (self.gamma_init) should be provided as a keyword argument when
        # initializing the larger Transformer model, since it requires knowledge
        # of the number of encoder/decoder layers in the model.

        nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
        if self.v_proj.bias is not None:
            nn.init.constant_(self.v_proj.bias, 0)
        nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0)

    def forward(
        self, query: Tensor, key: Tensor, value: Tensor, is_causal: bool = False
    ) -> Tuple[Tensor, None]:
        # Notation:
        #   b - batch size
        #   n - sequence length
        #   h - number of heads
        #   d - embedding dimension
        #
        # Input shape: (b, n, d)
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Unfold 'd' dimension into 'h' separate attention heads.
        q = rearrange(q, "b n (h d) -> b n h d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b n h d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b n h d", h=self.num_heads)

        # Apply attention, then fold 'h' attention heads back into 'd'.
        x = self.attention(q, k, v, causal=is_causal)
        x = rearrange(x, "b n h d -> b n (h d)")

        if self.layer_norm:
            assert self.norm is not None
            x = self.norm(x)

        # Linear projection on attention outputs.
        x = self.out_proj(x)

        return x, None
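
For reference, a minimal usage sketch of the module above (hypothetical shapes and hyperparameters; the actual behavior depends on the external DilatedAttention implementation it wraps):

# Hypothetical smoke test; segment lengths and dilation rates follow the LongNet
# convention of pairing longer segments with larger rates.
import torch

mha = MultiheadDilatedAttention(
    embed_dim=1024,
    num_heads=8,
    dilation_rates=[1, 2, 4],
    segment_lengths=[2048, 4096, 8192],
)

x = torch.randn(1, 8192, 1024)         # (batch, seq_len, embed_dim)
out, _ = mha(x, x, x, is_causal=True)  # self-attention over an 8k context
print(out.shape)                       # expected: torch.Size([1, 8192, 1024])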