danbraunai / simple_stories_train

Trains small LMs. Designed for training on SimpleStories

Verify RMS norm #6

Open danbraunai opened 2 months ago

danbraunai commented 2 months ago

The Hugging Face RMSNorm and the Torch RMSNorm give slightly different values (≈0.0029 on one input). It's unclear whether there's an issue here. Note that our RMSNorm uses the version from HF.

If we can't work out why the values differ and make them match, I suppose we can keep the HF implementation, since we'll be verifying our model by comparing against the HF implementation anyway.
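For reference, a minimal standalone comparison along these lines reproduces a small mismatch (a sketch, not code from the repo; the hidden size of 768 matches the model's, and the random input is illustrative, so the size of the gap will differ from the 0.0029 above):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_states = torch.randn(1, 4, 768)
weight = torch.ones(768)

# HF-style RMSNorm, with its eps of 1e-6
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hf_out = weight * (hidden_states * torch.rsqrt(variance + 1e-6))

# Torch built-in RMSNorm, left at its default eps
torch_out = F.rms_norm(hidden_states, [768], weight)

print((hf_out - torch_out).abs().max())  # small but nonzero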

ThomasWMarshall commented 2 months ago

I looked into this. I believe it just boils down to the eps parameter. When I pass eps=1e-6 to F.rms_norm, both implementations give the same result.

To test this, I temporarily modified forward in LlamaRMSNorm in train_llama.py to be the following:

def forward(self, hidden_states):
    # Torch built-in RMSNorm, with eps passed explicitly to match HF's epsilon
    print(F.rms_norm(hidden_states, [768], self.weight, eps=1e-6))
    # HF-style computation: upcast to float32, normalize by the RMS, cast back
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    result = self.weight * hidden_states.to(input_dtype)
    print(result)
    exit()  # stop after one forward pass so the two printed tensors can be compared
    return result

Note the inclusion of the eps parameter passed to F.rms_norm. If it is omitted, F.rms_norm falls back to its own default (torch.finfo(input.dtype).eps, per the PyTorch docs) rather than the 1e-6 used by the HF code, and we see the issue:

tensor([[[ 0.3514, -1.0652, -0.9252,  ...
tensor([[[ 0.3510, -1.0641, -0.9242,  ...

But if we include it, everything lines up:

tensor([[[ 0.3510, -1.0641, -0.9242,  ...
tensor([[[ 0.3510, -1.0641, -0.9242,  ...

I believe that we can feel confident adopting either implementation.
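As a standalone check (a sketch along the same lines as the test above; the names and random input are illustrative), the two agree within float tolerance once the epsilons match:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_states = torch.randn(1, 4, 768)
weight = torch.randn(768)
eps = 1e-6

# HF-style computation
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hf_out = weight * (hidden_states * torch.rsqrt(variance + eps))

# Torch built-in with the same eps
torch_out = F.rms_norm(hidden_states, [768], weight, eps=eps)

# Passes once both use eps=1e-6 (default float32 tolerances)
torch.testing.assert_close(torch_out, hf_out)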

danbraunai-apollo commented 2 months ago

Beautiful. Thanks for solving this! I'm fine with using either version (torch is shorter, but it's somewhat nice having the code be explicit). Feel free to submit a PR to switch to torch (with eps=1e-6, so that our implementation matches Llama better for #2) if you have a preference for it.
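Something like the following is what I have in mind (a sketch, not an actual PR; it keeps the existing constructor and just delegates the forward to the torch built-in):

import torch
from torch import nn
import torch.nn.functional as F

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Delegate to the torch built-in, passing our eps explicitly
        return F.rms_norm(
            hidden_states, self.weight.shape, self.weight, eps=self.variance_epsilon
        )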

ThomasWMarshall commented 2 months ago

Happy to help! I'm inclined to agree that having the explicit implementation is nice. So my recommendation is to keep the current implementation.