Closed ClashLuke closed 1 year ago
Checkpoint+Resume is currently broken for MoE, which indicates that the device IDs before and after training aren't necessarily the same. This is most likely because psum_scatter and device IDs are not guaranteed to be aligned. As a workaround, the psum_scatter IDs used during training should be used for indexing.
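A minimal sketch of the idea (assuming the usual `pmap` setup; the function and variable names here are illustrative, not from the repo): record the logical axis index each device sees inside the collective, since `psum_scatter`'s output ordering follows the pmapped axis order rather than the raw hardware device IDs, and use that index for expert parameter bookkeeping.

```python
import os
# Simulate multiple devices on CPU for demonstration purposes.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax
import jax.numpy as jnp


def scatter_with_ids(x):
    # axis_index gives this device's position along the pmapped axis;
    # psum_scatter distributes chunks of the summed result in this order,
    # independent of the underlying hardware device IDs.
    idx = jax.lax.axis_index("devices")
    out = jax.lax.psum_scatter(x, "devices", scatter_dimension=0)
    return out, idx


n = jax.local_device_count()
# Every device holds the same vector [0, 1, ..., n-1].
x = jnp.stack([jnp.arange(n, dtype=jnp.float32)] * n)
out, ids = jax.pmap(scatter_with_ids, axis_name="devices")(x)
# Device i receives chunk i of the full psum, i.e. n * i here,
# and ids records the logical index to key checkpoints on.
```

Keying checkpointed expert state on `ids` (the logical axis indices) rather than on `jax.devices()` ordering is what makes the layout reproducible after a restart.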
There were some issues with resuming from a checkpoint in this branch, but they are solved by https://github.com/HomebrewNLP/HomebrewNLP-Jax/commit/13feeb9ad4550b17c6eb2e194f8a62f3ce2179a1. Retrying MoE now.
Significantly better than the baseline (3% lower loss and still converging):
But only 30% slower:
Most importantly, the MoE model reaches a significantly lower loss. However, it also uses 9.6B parameters which puts the memory usage at 14GB (out of 16GB). Therefore, to use this without OOMs at 60M+ dense-equivalent parameters per device, #88 has to be merged first.
Achieved 2% lower loss than fully converged baseline with increased stability
7% lower loss, and still converging:
At the cost of 20% slower steps:
And 8x as many parameters: