Closed aboros98 closed 4 months ago
Hello!
In the PyTorch implementation, the MLP uses the exact GeLU as its gating activation,
while the JAX version uses the approximate (tanh) GeLU.
Could you please clarify which version is the intended one?
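For context, the two variants differ only in how the Gaussian CDF inside GeLU is computed: the exact form uses the error function, while the common tanh approximation (the one JAX's `jax.nn.gelu` applies by default) substitutes a polynomial-in-tanh expression. A minimal pure-Python sketch of both, for comparison (this is illustrative, not the repo's code):

```python
import math

def gelu_exact(x: float) -> float:
    # Exact GeLU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Tanh approximation (Hendrycks & Gimpel); what jax.nn.gelu
    # computes with its default approximate=True.
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

# The two agree to roughly 3-4 decimal places over typical activations.
for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.6f}  tanh={gelu_tanh(x):+.6f}")
```

In PyTorch the same choice is exposed as `nn.GELU()` (exact, the default) versus `nn.GELU(approximate='tanh')`, so switching the gating function is a one-argument change.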
I see this PR fixes it; I'll help land it soon: https://github.com/google/gemma_pytorch/pull/37
PR is checked in. Closing this.