ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Possible Gemma bugs #1057

Closed zcbenz closed 1 month ago

zcbenz commented 1 month ago

I came across https://github.com/huggingface/transformers/pull/29402 and I think it applies to MLX's gemma implementation too.

It has 3 possible problems:

  • MLP's activation function should be nn.gelu_approx instead of nn.gelu.
  • The h = h * (self.args.hidden_size**0.5) normalization needs casting to h.dtype.
  • The 1.0 + self.weight in RMSNorm needs casting to float32.

I think MLX casts scalars to the array's dtype in binary ops, so they may not all apply.

awni commented 1 month ago
  • MLP's activation function should be nn.gelu_approx instead of nn.gelu.

I switched it to nn.gelu_approx in #1062. I didn't see any noticeable difference in the generations with a quantized model, but it seems safer, so why not.
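For context, the two activations differ only by the tanh approximation to the Gaussian CDF. A minimal pure-Python sketch of both forms (the tanh form is what Gemma was trained with; whether it matches MLX's nn.gelu_approx kernel exactly is an assumption here):

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with the Gaussian CDF written via erf.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh_approx(x):
    # The common tanh approximation used by Gemma.
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
```

At moderate inputs the two agree to roughly 1e-3, which is consistent with the generations looking unchanged after the switch, especially under quantization.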

  • The h = h * (self.args.hidden_size**0.5) normalization needs casting to h.dtype.

MLX will correctly cast the normalizing scale to `h.dtype`.
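The scalar-promotion behavior in question can be sketched with NumPy, whose weak-scalar promotion is similar (this only illustrates the concept; MLX itself is not imported, and the hidden size is a made-up value):

```python
import numpy as np

# Stand-in for half-precision hidden states h.
h = np.ones((2, 4), dtype=np.float16)

# A Python float scalar, like args.hidden_size ** 0.5 in the Gemma code.
scale = 2048 ** 0.5

out = h * scale
# The Python scalar is treated as "weak" and does not upcast the array,
# so the result stays float16 with no explicit cast needed.
assert out.dtype == np.float16
```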

  • The 1.0 + self.weight in RMSNorm needs casting to float32.

RMSNorm (and most normalization layers in MLX) accumulates its statistics (mean, variance, etc.) in float32, which is the usual reason one casts before calling such layers. So upcasting the norm terms beforehand is in general not necessary.
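The accumulate-in-float32 idea can be sketched in NumPy; this is only an illustration of the pattern being described, not MLX's actual kernel, and the (1 + weight) term reflects Gemma's convention of storing the scale minus one:

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Accumulate the mean-square statistic in float32 even when the
    # input is half precision, so no upcast is needed at the call site.
    x32 = x.astype(np.float32)
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    # Gemma stores weight as (scale - 1), hence the (1 + weight).
    out = (x32 / rms) * (1.0 + weight.astype(np.float32))
    # Cast back to the input dtype for the rest of the network.
    return out.astype(x.dtype)
```

Because the division and reduction happen in float32, the low-precision artifacts that the transformers PR guards against never enter the statistic.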

zcbenz commented 1 month ago

Thanks for the insights! It's impressive that MLX dodged these problems through good design.