bkitano / llama-from-scratch

Llama from scratch, or How to implement a paper without crying
https://blog.briankitano.com/llama-from-scratch/

Incorrect RMSNorm #4

Open arunmallya opened 4 months ago

arunmallya commented 4 months ago

The RMSNorm implementation in this codebase is wrong: it computes the RMS over the (T, D) dimensions instead of over the (D) dimension alone. Assume the input x has shape (B, T, D).

The current code does this:

# x is (B, T, D).
ff_rms = torch.linalg.norm(x, dim=(1,2)) * x[0].numel() ** -.5  # (B,).
raw = x / ff_rms.unsqueeze(-1).unsqueeze(-1)  # divisor is (B, 1, 1): one scalar per sequence.

The original RMSNorm is here - https://github.com/meta-llama/llama/blob/main/llama/model.py#L34-L77

x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

The correct version, using torch.linalg.norm over the last dimension, would be:

ff_rms = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])  # (B, T, 1).
raw = x / (ff_rms + eps)  # (B, T, D).

Normalization should be per-token, not per-sequence.
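
For a quick sanity check, here is a minimal sketch (shapes and seed are made up, not from the repo) showing that the per-token version matches the Meta formulation up to where eps enters, while the current per-sequence version does not:

import math
import torch

torch.manual_seed(0)
B, T, D = 2, 5, 16
x = torch.randn(B, T, D)
eps = 1e-8

# Current code: one scalar per sequence.
ff_rms_seq = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -.5  # (B,).
per_sequence = x / ff_rms_seq.unsqueeze(-1).unsqueeze(-1)

# Proposed fix: one scalar per token.
ff_rms_tok = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])  # (B, T, 1).
per_token = x / (ff_rms_tok + eps)

# Reference formulation from the Meta llama code.
reference = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

print(torch.allclose(per_token, reference, atol=1e-5))     # True (eps is placed slightly differently).
print(torch.allclose(per_sequence, reference, atol=1e-5))  # False.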

nkkbr commented 2 months ago

I agree with you.

nkkbr commented 2 months ago

My version:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super().__init__()
        # layer_shape is (context_window, d_model); scale is a learnable per-element gain.
        # bias is kept for interface compatibility but is not used.
        self.register_parameter('scale', nn.Parameter(torch.ones(layer_shape)))
        self.eps = eps

    def forward(self, x):
        """
        Assumes x has shape (batch, seq_len, d_model); normalizes over the last dimension only.
        """
        f = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Slice the scale to the current sequence length and broadcast over the batch.
        return x * f * self.scale[:x.shape[1], :].unsqueeze(0)
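
A quick usage sketch (hypothetical sizes; since scale is initialized to ones, each token's mean square comes out near 1 before training):

norm = RMSNorm(layer_shape=(16, 64))  # context_window=16, d_model=64 (made-up sizes).
x = torch.randn(2, 10, 64)            # (batch, seq_len, d_model), seq_len <= context_window.
out = norm(x)
print(out.shape)                      # torch.Size([2, 10, 64])
print(out.pow(2).mean(-1))            # ~1.0 for every token, since scale starts at ones.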
bkitano commented 1 month ago

hi! open a PR?