microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

integration with Flax? #11

Open nestordemeure opened 2 years ago

nestordemeure commented 2 years ago

Is there any interest in integrating this work with Flax?

They already have an init function that decouples parameter initialization from the model definition, which could make introducing muP fairly plug-and-play.
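For illustration, here is a minimal sketch of that decoupling in Flax (the toy model and the `rescale_for_mup` idea are just examples, not code from mup or Flax):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    width: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.width)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)


model = MLP(width=256)
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 32))

# `init` returns the parameter pytree separately from the module itself,
# so a muP integration could post-process these parameters (e.g. apply a
# hypothetical rescale_for_mup(params) step) before training, without
# touching the model definition.
params = model.init(rng, dummy_input)
```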

Plus, they rely on optax for their optimizers. Since that library is focused on composability, you might be able to introduce a transformation that takes an optimizer and makes it muP-compatible.
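As a rough sketch of that idea (the helper names and the `width_multipliers` pytree here are hypothetical, not part of mup or optax):

```python
import jax
import optax


def scale_updates_by(multipliers):
    """A GradientTransformation that rescales each update leaf elementwise."""

    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params
        updates = jax.tree_util.tree_map(lambda u, m: u * m, updates, multipliers)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)


def make_mup_compatible(base_optimizer, width_multipliers):
    # Apply the base optimizer first, then rescale its updates per tensor
    # with muP-style width multipliers.
    return optax.chain(base_optimizer, scale_updates_by(width_multipliers))


# Usage sketch: mup_adam = make_mup_compatible(optax.adam(1e-3), width_multipliers)
```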

Overall, I believe the Flax ecosystem could make mup more easily accessible to people.

thegregyang commented 2 years ago

Integration with Flax would be fantastic, but neither I nor @edwardjhu is familiar with it. If someone from the Flax team can work with us, we can definitely advise on the integration process.

davisyoshida commented 2 years ago

@nestordemeure In case you're interested, I have a first draft of a port to JAX/Haiku here. If you're not attached to Flax in particular, you could use this. You could also probably adapt the design to Flax if you wanted, since Flax and Haiku are more similar to each other than Flax and PyTorch are.

Edit: @thegregyang By the way, can you take a look at the plots in the README there? The optimal learning rate stabilizes with width, but it does look like I sometimes see better training loss for SP. Is that indicative of a bug? My coord checks look good: nothing grows with width, and the output norm (at init) decays with width.
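For reference, the coord check I'm describing amounts to something like this minimal sketch (an illustrative toy network with assumed muP-style scaling, not the code in the repo):

```python
import jax
import jax.numpy as jnp


def forward(params, x):
    h = jax.nn.relu(x @ params["w_in"])
    return h @ params["w_out"]


def init_params(key, d_in, width):
    k1, k2 = jax.random.split(key)
    return {
        # Assumed muP-style scaling: fan-in init for the hidden layer,
        # 1/width scaling for the output layer.
        "w_in": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "w_out": jax.random.normal(k2, (width, 1)) / width,
    }


x = jax.random.normal(jax.random.PRNGKey(0), (16, 32))
for width in (128, 256, 512, 1024, 2048):
    params = init_params(jax.random.PRNGKey(width), 32, width)
    out = forward(params, x)
    # The typical output size should stay bounded (here it decays) as width grows.
    print(width, float(jnp.abs(out).mean()))
```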

thegregyang commented 2 years ago

Hey @davisyoshida, your repo looks great so far!

For your plot, you'd get better results if you tune the input, output, and hidden learning rates for your small model and scale up from there, sweeping a global lr multiplier on the x-axis (ideally you would tune (lr, init) for all parameter tensors, but these three learning rates should be a good practical approximation). In particular, for a fair comparison, the curves for your small model in both the SP and muP plots should be the same. Your current plots are only looking at a slice of the HP space (of (lr, init) for all parameter tensors) away from the true optimum.
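Concretely, the sweep could be set up along these lines (the parameter labels and base rates below are placeholder assumptions, not values from the repo):

```python
import optax

# Base learning rates tuned once on the small model, one per tensor group.
base_lrs = {"input": 3e-3, "hidden": 1e-3, "output": 3e-4}


def make_optimizer(global_mult, param_labels):
    # `param_labels` is a pytree mapping each parameter to one of the three
    # groups, e.g. built with jax.tree_util.tree_map_with_path.
    transforms = {
        group: optax.adam(global_mult * lr) for group, lr in base_lrs.items()
    }
    return optax.multi_transform(transforms, param_labels)


# Sweep only the global multiplier on the x-axis, keeping the tuned ratios fixed:
# for mult in (0.25, 0.5, 1.0, 2.0, 4.0): train with make_optimizer(mult, labels)
```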

davisyoshida commented 2 years ago

Ah that makes perfect sense, I'll generate new versions of the figures. Thanks!