HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Transfer weights across size #51

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

Current experiments indicate significantly improved convergence with https://github.com/HomebrewNLP/HomebrewNLP-Jax/commit/bfe80d462fead7195b8a048bdfa6e778bac97048:

[screenshot: loss curves]

Another run based on https://github.com/HomebrewNLP/HomebrewNLP-Jax/pull/51/commits/935c543ba63fd643f120a69f577d9fa4d99e94ed should start soon enough

ClashLuke commented 2 years ago

Initializing the normalization to 0 is necessary and helps the model quite a lot (purple) compared to the baseline (blue):

[screenshot: loss curves]

While the baseline stops converging, the model with the normalization initialized to 0 (and no other changes) continues to converge as expected:

[screenshot: loss curves]

This stall might occur because the model with random initialization has to learn not to use the new blocks, pushing all of their weights towards zero. Once the optimizer has converged into a state where it pushes the weights towards zero, it is difficult to move out of that region again.
On the other hand, if we initialize the normalization to 0, the gradients push in random directions, forcing the model to get used to all layers slowly. This change in gradient behaviour might be why ReZero works better on paper.
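To make the idea concrete, here is a minimal, illustrative JAX sketch (not the repository's actual block; all names such as `init_block`, `block`, and `scale` are made up for this example) of a residual block gated by a zero-initialized normalization scale. At initialization the residual branch contributes nothing, so a newly added block starts as an identity mapping, while the gradient on the scale stays non-zero and lets the model fade the block in:

```python
# Illustrative sketch only: zero-initialized output scale on a residual block.
import jax
import jax.numpy as jnp


def init_block(key, dim: int, hidden: int):
    k_in, k_out = jax.random.split(key)
    return {
        "w_in": jax.random.normal(k_in, (dim, hidden)) * 0.02,
        "w_out": jax.random.normal(k_out, (hidden, dim)) * 0.02,
        # Zero-initialized scale: multiplies the block output, so the block
        # is a no-op at initialization (compare ReZero's residual gate).
        "scale": jnp.zeros((dim,)),
    }


def block(params, x):
    # Scale-free RMS normalization of the block input.
    normed = x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)
    hidden = jax.nn.relu(normed @ params["w_in"])
    out = hidden @ params["w_out"]
    # The zero-initialized scale gates how much the block adds to the residual stream.
    return x + out * params["scale"]
```

With this placement, only `scale` receives a non-zero gradient at first; the inner weights start updating once the scale has moved away from zero, which matches the "get used to all layers slowly" behaviour described above. Where exactly the repository applies its normalization scale may differ.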

Overall, the transfer seems to work fine, so the PR can be merged once the new model outperforms the one it was transferred from.
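The depth-transfer idea itself can be sketched as follows; this is a hypothetical illustration (the function `grow_depth` and its signature are invented for this example, not the repository's transfer code): keep the trained blocks' parameters and append new blocks whose scale is zero, so the deeper model initially computes the same function as the model it was transferred from.

```python
# Hypothetical depth-transfer sketch: reuse trained blocks, add identity blocks.
import jax
import jax.numpy as jnp


def grow_depth(old_blocks, new_depth: int, key, dim: int, hidden: int):
    """old_blocks: list of per-block parameter dicts from the smaller model."""
    blocks = list(old_blocks)  # reuse trained parameters unchanged
    for _ in range(new_depth - len(old_blocks)):
        key, k_in, k_out = jax.random.split(key, 3)
        blocks.append({
            "w_in": jax.random.normal(k_in, (dim, hidden)) * 0.02,
            "w_out": jax.random.normal(k_out, (hidden, dim)) * 0.02,
            "scale": jnp.zeros((dim,)),  # new block starts as an identity residual
        })
    return blocks
```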

ClashLuke commented 2 years ago

The transferred model (purple) converges to a slightly lower loss than the baseline (green):

[screenshots: loss curves]

This improved convergence implies that the model learns to use the new layers, which is what we were aiming for.