google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

Inconsistency between PyTorch and JAX implementation #33

Status: Closed (aboros98 closed this issue 4 months ago)

aboros98 commented 4 months ago

Hello!

In the PyTorch implementation, the MLP uses the exact GELU as its gating activation.

The JAX version uses the approximate GELU instead.

Could you please clarify which version is the correct one?
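For readers without the screenshots, the difference in question can be sketched in plain Python. This is only an illustration, not code from either repository: the function names `gelu_exact`, `gelu_tanh`, and `max_abs_gap` are hypothetical, and the sketch assumes that "approximate" refers to the common tanh approximation of GELU (the form selected by `approximate="tanh"` in `torch.nn.functional.gelu` and by `approximate=True` in `jax.nn.gelu`).

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU (assumed here to be the 'approximate' variant)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def max_abs_gap(lo: float = -6.0, hi: float = 6.0, steps: int = 1201) -> float:
    """Largest absolute difference between the two variants on an evenly spaced grid."""
    gap = 0.0
    for i in range(steps):
        x = lo + (hi - lo) * i / (steps - 1)
        gap = max(gap, abs(gelu_exact(x) - gelu_tanh(x)))
    return gap
```

The two variants agree closely but not exactly, so a checkpoint trained with one and evaluated with the other would see slightly different activations in every MLP layer, which is why the question of which one is correct matters.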

pengchongjin commented 4 months ago

I see this PR fixes it; I will help land it soon. https://github.com/google/gemma_pytorch/pull/37

pengchongjin commented 4 months ago

PR is checked in. Closing this.