berlino / gated_linear_attention

MIT License
97 stars 2 forks

advice for small sized GLA #1

Closed theodorblackbird closed 10 months ago

theodorblackbird commented 11 months ago

Thank you for this amazing work,

I'm trying to use your work as a drop-in replacement for other SSM-style models such as Mamba and RWKV. Note that I train significantly smaller models (from 20M to 60M parameters), and not for natural language generation. I have nonetheless gotten encouraging results and I believe GLA should be competitive, but so far I fail to match RWKV/Mamba, despite promising speed/VRAM usage.

I have a couple of questions about integrating GLA correctly:

  1. What is your advice on parameter choices when scaling the GLA Transformer down to ~20-60M parameters, in terms of layers/dimensions/heads?
  2. Can you confirm that my interpretation of a GLABlock is correct?
import torch
import torch.nn as nn


class SwiGLU(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # ~8/3 overall expansion: project to a gate half and a value half of
        # 4/3 * d_model each, so the two chunks always match p_out's input size.
        d_hidden = d_model * 4 // 3
        self.p_in = nn.Linear(d_model, 2 * d_hidden)
        self.p_out = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        gate, x = self.p_in(x).chunk(2, dim=-1)
        return self.p_out(nn.functional.silu(gate) * x)

class GLABlock(nn.Module):
    def __init__(self, d_model, heads):
        super().__init__()
        self.att = GatedLinearAttention(d_model, heads)
        self.ffn = SwiGLU(d_model)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        # pre-norm attention with a residual, then the FFN on the result
        y = self.att(self.ln(x)) + x
        y = self.ffn(y)
        return y
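
For reference, the block is meant to be used roughly like this (assuming GatedLinearAttention(d_model, heads) preserves the [batch, length, d_model] shape, which is my assumption):

# Illustrative usage only; the GatedLinearAttention(d_model, heads) interface
# above is assumed to map [batch, length, d_model] to the same shape.
block = GLABlock(d_model=256, heads=4)
x = torch.randn(2, 128, 256)
y = block(x)
print(y.shape)  # torch.Size([2, 128, 256])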
sustcsonglin commented 10 months ago

self.ffn(y) should be self.ffn(self.ln2(y)).
What is the performance gap? For a smaller model, D_model is small and therefore D_head is small. We would suggest using a small number of heads, e.g., 1, to keep D_head >= 64.
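
A minimal sketch of the block with that fix, assuming the same GatedLinearAttention(d_model, heads) interface as above; the residual around the FFN follows the standard pre-norm Transformer layout and is an assumption on my part rather than part of the suggested fix:

class GLABlock(nn.Module):
    def __init__(self, d_model, heads):
        super().__init__()
        self.att = GatedLinearAttention(d_model, heads)
        self.ffn = SwiGLU(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)  # second LayerNorm, in front of the FFN

    def forward(self, x):
        x = self.att(self.ln1(x)) + x
        # FFN residual assumed here, following the usual pre-norm Transformer block
        x = self.ffn(self.ln2(x)) + x
        return x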

theodorblackbird commented 10 months ago

> self.ffn(y) should be self.ffn(self.ln2(y)).

I had missed that. Interestingly, this alone seems to make it competitive. I will keep investigating.

[Screenshot from 2023-12-17 19:35]

Thank you.

sustcsonglin commented 10 months ago

It seems we omitted that detail in the paper. I checked our implementation and it does use two LayerNorms per block, as in standard Transformers. I would be interested in the exact numbers for the performance gap. Our smallest experimental scale is 350M parameters. My general sense is that for smaller models token mixing is more important, so you might want to try the parameter allocation used in RetNet, i.e., D_k = D_model and D_v = 2 * D_model. Our finding is that for large-scale models, allocating more parameters to the FFNs is more important.
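
As a rough sketch of what that allocation means in terms of projection shapes (the variable names and explicit projections below are illustrative; the actual GatedLinearAttention constructor may expose this differently):

import torch.nn as nn

# Illustrative only: RetNet-style allocation for a small model.
# These names are not the actual GatedLinearAttention API.
d_model = 384
d_k = d_model            # query/key dimension: D_k = D_model
d_v = 2 * d_model        # value dimension:     D_v = 2 * D_model
n_head = 1               # few heads, so the per-head dimension stays >= 64

q_proj = nn.Linear(d_model, d_k)
k_proj = nn.Linear(d_model, d_k)
v_proj = nn.Linear(d_model, d_v)
o_proj = nn.Linear(d_v, d_model)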