In `mlx/nn/layers/activations.py`, the approximation options denoted by the `approx` argument are `'none'` (default), `'precise'`, and `'fast'`. In the PyTorch API, the `'precise'` approximation is denoted by `'tanh'`:
$\textrm{GELUApprox}(x) = 0.5x \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}(x + 0.044715x^3)\right)\right)$
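As a quick sanity check of the formula, here is a small standard-library snippet (not MLX code) comparing the exact GELU, $x\,\Phi(x)$, against the tanh approximation above:

```python
import math

def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # The 'precise' (PyTorch: 'tanh') approximation shown above.
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x**3)))

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.6f}  tanh={gelu_tanh(x):+.6f}")
```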
We could change line 555 (with corresponding updates at lines 561 and 547) to:

`approx ('none' | 'precise'/'tanh' | 'fast'): Which approximation to gelu to use if any.`
(and add tests where relevant) to match the torch API and make code migration easier, if that is within the scope of MLX. I'd be happy to open a PR with tests if that's something we want to do.
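Below is a minimal sketch of how the alias could work, assuming the class dispatches on `approx` to the module-level helpers `gelu`, `gelu_approx`, and `gelu_fast_approx` that `activations.py` defines. This is illustrative only, not the actual MLX source (the real class also subclasses `nn.Module`):

```python
# Illustrative sketch, not the actual MLX source. Assumes the existing
# module-level helpers in mlx/nn/layers/activations.py.
from mlx.nn.layers.activations import gelu, gelu_approx, gelu_fast_approx

class GELU:
    def __init__(self, approx: str = "none"):
        if approx == "tanh":  # proposed alias matching torch.nn.GELU(approximate="tanh")
            approx = "precise"
        if approx == "none":
            self._act = gelu
        elif approx == "precise":
            self._act = gelu_approx
        elif approx == "fast":
            self._act = gelu_fast_approx
        else:
            raise ValueError(f"Unknown approximation {approx!r}")

    def __call__(self, x):
        return self._act(x)
```

With this, `GELU(approx="tanh")` would behave identically to `GELU(approx="precise")`, mirroring `torch.nn.GELU(approximate="tanh")`.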