google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Feature request: Mixture of Experts example #4034

Closed: SamKG closed this issue 5 days ago

SamKG commented 6 days ago

Hello,

I am wondering whether there are any examples that use Flax (or just pure JAX) for mixture-of-experts models. I'd be happy to contribute one myself if there aren't any; I'm just wondering whether anyone has already done the heavy lifting.
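For readers landing on this issue, here is a minimal sketch of a top-k gated mixture-of-experts layer in pure JAX. It is not the code from the discussion linked later in this thread, and it is not a Flax API. The function names `init_params` and `moe_layer`, the parameter layout, and the choice to evaluate every expert densely are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def init_params(key, num_experts, d_model, d_hidden):
    # Hypothetical parameter layout: one gating matrix plus stacked expert MLP weights.
    k_gate, k_w1, k_w2 = jax.random.split(key, 3)
    return {
        "gate": jax.random.normal(k_gate, (d_model, num_experts)) * 0.02,
        "w1": jax.random.normal(k_w1, (num_experts, d_model, d_hidden)) * 0.02,
        "w2": jax.random.normal(k_w2, (num_experts, d_hidden, d_model)) * 0.02,
    }

def moe_layer(params, x, top_k=2):
    # x: (tokens, d_model)
    num_experts = params["gate"].shape[-1]
    logits = x @ params["gate"]                          # (tokens, experts)
    top_vals, top_idx = jax.lax.top_k(logits, top_k)     # (tokens, top_k)
    weights = jax.nn.softmax(top_vals, axis=-1)          # renormalize over chosen experts
    # Scatter the top-k weights into a dense (tokens, experts) gate matrix.
    gates = jnp.sum(jax.nn.one_hot(top_idx, num_experts) * weights[..., None], axis=1)
    # Dense evaluation: every expert runs on every token (simple, not efficient).
    h = jax.nn.relu(jnp.einsum("td,edh->teh", x, params["w1"]))
    expert_out = jnp.einsum("teh,ehd->ted", h, params["w2"])  # (tokens, experts, d_model)
    # Combine expert outputs weighted by the (mostly zero) gates.
    y = jnp.einsum("te,ted->td", gates, expert_out)
    return y, gates

params = init_params(jax.random.PRNGKey(0), num_experts=4, d_model=8, d_hidden=16)
x = jax.random.normal(jax.random.PRNGKey(1), (5, 8))
y, gates = jax.jit(moe_layer, static_argnames="top_k")(params, x, top_k=2)
print(y.shape)  # (5, 8)
```

Computing all experts and masking with `gates` keeps the sketch short and shape-static under `jax.jit`, but it wastes compute on experts with zero weight. Practical MoE implementations instead dispatch tokens to experts, usually with a per-expert capacity limit and an auxiliary load-balancing loss.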

SamKG commented 5 days ago

Found one here: https://github.com/google/flax/discussions/4035