Open ZhiyuLi-goog opened 1 week ago
Thanks for the fix, Zhiyu. I'll leave it for @RissyRan to review.
Thanks Zhiyu for the fix! I am fine with adding this normalization if it makes it more convenient for us to compare weights. Could you add `top_k_weights /= top_k_weights.sum(-1, keepdims=True)` before this line: `weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)`? It was not used anywhere before that.
I think we also need to add this normalization somewhere [here](https://github.com/AI-Hypercomputer/maxtext/blob/061abd8a375b6b279eacf0ae867e05e4e08cc360/MaxText/layers/linears.py#L488) for the dropping path: `softmax_probs *= combined_expert_mask`.
Have you checked the benchmark scores with and without this normalization, using a correct ckpt on 8x22b or 8x7b?
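For context, here is a minimal sketch of what the two requested changes could look like, assuming routing weights computed as a softmax over router logits. The function names, shapes, and the exact placement of the dropping-path renormalization are illustrative assumptions, not MaxText's actual code:

```python
# Minimal sketch (illustrative names and shapes, not MaxText's routing code).
import jax
import jax.numpy as jnp

def route_dropless(router_logits, top_k=2):
    """Top-k routing with the requested sum-to-1 normalization."""
    # router_logits: [num_tokens, num_experts]
    softmax_probs = jax.nn.softmax(router_logits, axis=-1)
    top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, top_k)
    # Requested normalization: each token's selected weights sum to 1.
    top_k_weights /= top_k_weights.sum(-1, keepdims=True)
    return top_k_weights, top_k_indices

def mask_and_renormalize(softmax_probs, combined_expert_mask, eps=1e-9):
    """One possible placement of the same normalization in the dropping path:
    renormalize after dropped assignments are zeroed out by the mask."""
    # softmax_probs, combined_expert_mask: [num_tokens, num_experts]
    softmax_probs *= combined_expert_mask            # existing masking step
    denom = softmax_probs.sum(-1, keepdims=True)
    return softmax_probs / jnp.maximum(denom, eps)   # keep per-token sums at 1

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # 4 tokens, 8 experts
weights, _ = route_dropless(logits)
print(weights.sum(-1))  # ~1.0 per token
```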
> Thanks Zhiyu for the fix! I am fine with adding this normalization if it makes it more convenient for us to compare weights. Could you add `top_k_weights /= top_k_weights.sum(-1, keepdims=True)` before this line: `weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)`? It was not used anywhere before that.

Done.
> I think we also need to add this normalization somewhere [here](https://github.com/AI-Hypercomputer/maxtext/blob/061abd8a375b6b279eacf0ae867e05e4e08cc360/MaxText/layers/linears.py#L488) for the dropping path: `softmax_probs *= combined_expert_mask`.

> Have you checked the benchmark scores with and without this normalization, using a correct ckpt on 8x22b or 8x7b?

I have checked the benchmark scores; the results are the same with and without this normalization.
Additionally, I added a numerical verification notebook. I am currently able to get within ~0.1 tolerance on the logits for both `megablox=True` and `megablox=False`.
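For reference, the kind of check the notebook performs can be sketched as follows; the function name and tolerance handling are illustrative, not the notebook's actual code:

```python
# Illustrative logits comparison: report the max absolute difference between
# a reference implementation and MaxText, and check it against a tolerance.
import numpy as np

def compare_logits(reference_logits, maxtext_logits, atol=0.1):
    ref = np.asarray(reference_logits, dtype=np.float32)
    got = np.asarray(maxtext_logits, dtype=np.float32)
    max_abs_diff = float(np.max(np.abs(ref - got)))
    print(f"max |reference - maxtext| = {max_abs_diff:.4f}")
    return np.allclose(ref, got, atol=atol)
```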
@gobbleturk could you review it at your convenience?
Hi @gobbleturk, I need a code owner's review as the final step of this PR, thank you!
In addition to the typo, it looks like we are missing this normalization in MaxText, which rescales `top_k_weights` so that they sum to 1 per token. I was able to match layer outputs after adding the normalization. This normalization won't affect training/inference, since it is a per-token constant factor that doesn't change the relative softmax probabilities, but we can still add it for better alignment.
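As a small numeric illustration of that point (not code from the PR): dividing each token's top-k weights by their sum rescales them by a per-token constant, so the ratios between the selected experts' weights are unchanged.

```python
import jax.numpy as jnp

# Top-k routing weights for 2 tokens with top_k=2 (made-up values).
top_k_weights = jnp.array([[0.30, 0.15],
                           [0.40, 0.10]])

normalized = top_k_weights / top_k_weights.sum(-1, keepdims=True)

print(normalized)                                  # [[0.667, 0.333], [0.8, 0.2]]
print(top_k_weights[:, 0] / top_k_weights[:, 1])   # ratios: [2., 4.]
print(normalized[:, 0] / normalized[:, 1])         # same ratios: [2., 4.]
```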