danbraunai opened this issue 2 months ago

Huggingface RMSNorm and Torch RMSNorm give slightly different values (≈0.0029 on one input). Unclear if there's an issue here. Note that our RMSNorm uses the version from HF.

If we can't verify why we get different values and match them, I suppose we can keep the HF implementation, as we'll be verifying our model by comparing to the HF implementation anyway.
I looked into this. I believe it just boils down to the `eps` parameter: when I pass `eps=1e-6` to `F.rms_norm`, both implementations give the same result.

To test this, I temporarily modified `forward` in `LlamaRMSNorm` in `train_llama.py` to the following:
```python
def forward(self, hidden_states):
    # Torch built-in, with eps explicitly set to match HF's 1e-6.
    print(F.rms_norm(hidden_states, [768], self.weight, eps=1e-6))
    # Existing explicit HF-style implementation (float32 accumulation).
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    result = self.weight * hidden_states.to(input_dtype)
    print(result)
    exit()  # stop after the first call so we compare just one batch
    return result
```
Note the `eps` parameter passed to `F.rms_norm`. If it is omitted, `F.rms_norm` falls back to its default (per the PyTorch docs, `torch.finfo(dtype).eps`, which is about 1.19e-7 for float32 rather than HF's 1e-6), and we see the issue:
```
tensor([[[ 0.3514, -1.0652, -0.9252, ...
tensor([[[ 0.3510, -1.0641, -0.9242, ...
```
But if we include it, everything lines up:
```
tensor([[[ 0.3510, -1.0641, -0.9242, ...
tensor([[[ 0.3510, -1.0641, -0.9242, ...
```
I believe we can feel confident adopting either implementation.
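For anyone who wants to reproduce the comparison outside the training script, here's a minimal standalone sketch (the random input and seed are made up for illustration; the hidden size of 768 and `eps=1e-6` come from the thread, and `F.rms_norm` needs a reasonably recent PyTorch):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size = 768  # matches the [768] normalized shape used above
eps = 1e-6         # HF's LlamaRMSNorm default
x = torch.randn(1, 4, hidden_size)
weight = torch.ones(hidden_size)

# Explicit HF-style RMSNorm (float32 accumulation, as in LlamaRMSNorm).
variance = x.float().pow(2).mean(-1, keepdim=True)
hf_out = (weight * (x.float() * torch.rsqrt(variance + eps))).to(x.dtype)

# Torch built-in: once with the matching eps, once with the default
# (torch.finfo(dtype).eps when eps is not given).
torch_matched = F.rms_norm(x, [hidden_size], weight, eps=eps)
torch_default = F.rms_norm(x, [hidden_size], weight)

print((hf_out - torch_matched).abs().max())  # expect ~0
print((hf_out - torch_default).abs().max())  # expect tiny but nonzero
```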
Beautiful. Thanks for solving this! I'm fine with using either version (torch is shorter but it's somewhat nice having the code be explicit). Feel free to submit a PR to switch to torch (with eps=1e-6 so that our implementation matches llama better for #2) if you have a preference for it.
Happy to help! I'm inclined to agree that having the explicit implementation is nice. So my recommendation is to keep the current implementation.
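For reference, the explicit implementation we'd be keeping looks roughly like this (the `forward` is reconstructed from the snippet above; the `__init__` follows HF's `LlamaRMSNorm`, so treat the exact signature as an assumption):

```python
import torch
import torch.nn as nn

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the RMS statistic in float32 for numerical stability,
        # then cast back to the input dtype at the end.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

Whichever version we keep, the main thing is that `eps` stays at 1e-6 so comparisons against the HF implementation remain meaningful.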