HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Moe2 #85

Closed: ClashLuke closed this 1 year ago

ClashLuke commented 1 year ago

7% lower loss, and still converging: [loss curves]

At the cost of 20% slower steps: [step-time plot]

And 8x as many parameters: [parameter-count plot]

ClashLuke commented 1 year ago

Checkpoint+Resume is currently broken for MoE, which indicates that the device IDs before and after training aren't necessarily the same. This is most likely because psum_scatter shard order and device IDs are not guaranteed to be aligned. As a workaround, the psum_scatter indices used during training should also be used for indexing when restoring.
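For reference, a minimal sketch of that workaround (assumptions: `sharded_step`, `expert_table`, and the shapes are illustrative, not the actual Olmax code): derive per-device indices from `jax.lax.axis_index` inside the pmapped step, so indexing follows the same ordering that `psum_scatter` uses rather than whatever order `jax.devices()` reports after a restart.

```python
import jax
import jax.numpy as jnp

def sharded_step(x, expert_table):
    # Index of this device *within the mapped axis*: collectives such as
    # psum_scatter follow this ordering, not the host-side jax.devices()
    # order, so per-device indexing should be derived from it.
    shard_id = jax.lax.axis_index('devices')

    # Each device applies "its" expert from a replicated table (illustrative).
    y = x @ expert_table[shard_id]              # per device: (n_dev, d)

    # Sum across devices, then scatter: each device keeps the slice that
    # corresponds to its axis index, i.e. the same `shard_id` as above.
    return jax.lax.psum_scatter(y, 'devices')

n_dev = jax.local_device_count()
d = 8
x = jnp.ones((n_dev, n_dev, d))                 # one input shard per device
experts = jnp.ones((n_dev, n_dev, d, d))        # expert table, replicated per device

out = jax.pmap(sharded_step, axis_name='devices')(x, experts)
print(out.shape)
```

Using the axis index for both the expert lookup and the scatter means the checkpointed shards can be re-associated with the right experts on resume, even if the physical device enumeration changes.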

ClashLuke commented 1 year ago

There were some issues with resuming from a checkpoint in this branch, but they are solved by https://github.com/HomebrewNLP/HomebrewNLP-Jax/commit/13feeb9ad4550b17c6eb2e194f8a62f3ce2179a1. Retrying MoE now.

ClashLuke commented 1 year ago

Significantly better than baseline (3% lower loss and still converging): [loss curve]

But only 30% slower: [step-time plot]

Most importantly, the MoE model reaches a significantly lower loss. However, it also uses 9.6B parameters, which puts memory usage at 14 GB (out of 16 GB). Therefore, to use this without OOMs at 60M+ dense-equivalent parameters per device, #88 has to be merged first.

ClashLuke commented 1 year ago

Achieved 2% lower loss than the fully converged baseline, with increased stability.