- MLP's activation function should be `nn.gelu_approx` instead of `nn.gelu`.
I switched it to `nn.gelu_approx` in #1062. I didn't see any noticeable difference in the generations with a quantized model, but it seems safer so why not.
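For reference, a minimal sketch of a Gemma-style gated MLP using the tanh-approximate GELU; the class and projection names here are illustrative, not necessarily what mlx-lm uses:

```python
import mlx.core as mx
import mlx.nn as nn


class GemmaMLP(nn.Module):
    # Gated MLP as used in Gemma; layer names are illustrative.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        # Tanh-approximate GELU instead of the exact (erf-based) nn.gelu.
        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
```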
- The `h = h * (self.args.hidden_size**0.5)` normalization needs casting to `h.dtype`.
MLX will correctly cast the normalizing scale to `h.dtype`.
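A quick way to check this behavior (the hidden size below is just an illustrative value):

```python
import mlx.core as mx

h = mx.ones((2, 4), dtype=mx.float16)
hidden_size = 2048  # illustrative value

# Multiplying by a Python scalar keeps the array's dtype in MLX,
# so the scaled hidden states stay in float16.
scaled = h * (hidden_size**0.5)
print(scaled.dtype)  # mlx.core.float16
```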
- The `1.0 + self.weight` in `RMSNorm` needs casting to float32.
`RMSNorm` (and most normalization layers in MLX) accumulates the statistics (mean, var, etc.) in float32, which is usually the reason one casts before calling them. So upcasting the norm terms beforehand is in general not necessary.
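For intuition, a plain sketch of an RMSNorm that accumulates in float32; this is not MLX's actual implementation (the built-in `nn.RMSNorm` is backed by a fused kernel), just an illustration of why the caller does not need to upcast:

```python
import mlx.core as mx


def rms_norm_ref(x: mx.array, weight: mx.array, eps: float = 1e-6) -> mx.array:
    # Upcast so the mean-square statistic is accumulated in float32,
    # even for float16/bfloat16 inputs.
    x32 = x.astype(mx.float32)
    norm = x32 * mx.rsqrt(mx.mean(mx.square(x32), axis=-1, keepdims=True) + eps)
    # Gemma scales by (1 + weight); the addition happens in float32 here,
    # and the result is cast back to the input dtype at the end.
    return ((1.0 + weight.astype(mx.float32)) * norm).astype(x.dtype)
```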
Thanks for the insights! It is amazing that MLX dodged these problems through good design.
I came across https://github.com/huggingface/transformers/pull/29402 and I think it applies to MLX's gemma implementation too.
It has 3 possible problems:
- MLP's activation function should be `nn.gelu_approx` instead of `nn.gelu`.
- The `h = h * (self.args.hidden_size**0.5)` normalization needs casting to `h.dtype`.
- The `1.0 + self.weight` in `RMSNorm` needs casting to float32.

I think MLX casts scalars to the array's dtype in binary ops, so they may not all apply.