ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Feature Request] Adding option to follow PyTorch API in GELU initialization #1265

Closed AtakanTekparmak closed 2 months ago

AtakanTekparmak commented 3 months ago

In mlx/nn/layers/activations.py, the approximation options for the approx argument are 'none' (default), 'precise', and 'fast'. In the PyTorch API, the precise approximation is denoted by tanh: $\textrm{GELUApprox}(x) = 0.5x \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}(x + 0.044715x^3)\right)\right)$. We could change line 555 to:

elif approx == "precise" or approx == "tanh":

line 561 to:

f"The approximation should be in ['none', 'precise'/'tanh', 'fast'] but '{approx}' was given"

and line 547 to:

approx ('none' | 'precise'/'tanh' | 'fast'): Which approximation to GELU to use, if any.

(and add tests where relevant) to match the torch API and make code migration easier, if that is within the scope of MLX. I can add tests and open a PR if that's something we want to do.
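The proposed dispatch can be sketched as follows. This is a minimal scalar sketch, not MLX's actual implementation: plain `math` calls stand in for mlx array operations, and the 'fast' branch assumes the common sigmoid-based approximation. The point is that 'tanh' becomes an alias for 'precise', mirroring PyTorch's `approximate='tanh'`:

```python
import math

def gelu(x: float, approx: str = "none") -> float:
    # Sketch of the proposed option handling; scalar math stands in
    # for mlx array ops, so this is illustrative only.
    if approx == "none":
        # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
        return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    elif approx == "precise" or approx == "tanh":
        # tanh approximation, matching PyTorch's approximate='tanh'.
        return 0.5 * x * (
            1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
        )
    elif approx == "fast":
        # Assumed sigmoid-based fast approximation.
        return x / (1.0 + math.exp(-1.702 * x))
    else:
        raise ValueError(
            f"The approximation should be in ['none', 'precise'/'tanh', 'fast'] "
            f"but '{approx}' was given"
        )

# 'precise' and 'tanh' select the same branch, so results are identical:
print(gelu(1.0, "precise") == gelu(1.0, "tanh"))  # True
```

Since 'tanh' and 'precise' share one branch, existing callers are unaffected and only the error message and docstring need updating alongside the condition.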

awni commented 3 months ago

I don't have a problem with adding this. If you send a PR we can take a look.