HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Scale #71

Closed. ClashLuke closed this issue 1 year ago.

ClashLuke commented 1 year ago

Observations:

1. Adam gets unstable at scale as gradients become smaller and sparser, so the second-moment estimate (tracked via beta2) decays to 0. However, once one gradient spike happens, the first-moment estimate (tracked via beta1) shoots up almost immediately, causing a large number to be divided by roughly eps, which disturbs the updates a lot. Current solution: SM3. Looking into MADGRAD and Mirror-MADGRAD as fixes. Untested whether Shampoo is affected as well.
2. Distributed dataloaders are difficult but possible. Current solution: only one dataloader and no input stacking.
3. Small init + input norm causes significant instability (of the bad kind, see Figure 1). Solution: revert to normal init.
4. FP64 in JAX works out of the box, but Shampoo was missing some casts. The entire optimizer now runs in FP64.
5. Dying ReLU is a real problem: many units simply die (their weights go to 0) and are rarely activated again. Also, LeakyReLU performs worse than Swish/GeLU/Mish/LiSHT. Switching to LeCunTanH, which doesn't have these problems, still allows for inversion, and gives similar or better results on toy datasets (see the sketch below). New problem: pruning won't be possible.
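
For reference, a minimal JAX sketch of a LeCun-style scaled tanh (constants from LeCun's "Efficient BackProp"); the exact activation and constants used in this repo may differ:

```python
import jax.numpy as jnp

def lecun_tanh(x: jnp.ndarray) -> jnp.ndarray:
    """LeCun's scaled tanh, f(x) = 1.7159 * tanh(2/3 * x).

    Smooth and strictly monotonic, so units never "die" the way ReLU units can,
    and the function can be inverted analytically (useful for reversible blocks).
    """
    return 1.7159 * jnp.tanh(2.0 / 3.0 * x)

def lecun_tanh_inverse(y: jnp.ndarray) -> jnp.ndarray:
    """Analytic inverse: x = 1.5 * artanh(y / 1.7159)."""
    return 1.5 * jnp.arctanh(y / 1.7159)

x = jnp.linspace(-3.0, 3.0, 7)
assert jnp.allclose(lecun_tanh_inverse(lecun_tanh(x)), x, atol=1e-4)
```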

Figure 1: training instability with small init + input norm *(image not preserved)*

ClashLuke commented 1 year ago

Shampoo improves stability *(images)*.

Keeping it.

ClashLuke commented 1 year ago

This was probably the most difficult thing to find *(image)*.

We used different inputs on each model-parallel device in all previous attempts (with the sharded Dataloader).
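
As an illustration (not the repo's dataloader code), one way to ensure every model-parallel device sees the same batch is to replicate a single host batch across all local devices instead of sharding it; the `batch` array and the use of `jax.device_put_sharded` here are assumptions:

```python
import numpy as np
import jax

# Hypothetical batch from the single (non-sharded) dataloader.
batch = np.random.randint(0, 50_000, size=(16, 1024), dtype=np.int32)

# Model parallelism needs *identical* inputs on every device in the
# model-parallel group, so replicate the batch instead of sharding it.
devices = jax.local_devices()
replicated = np.broadcast_to(batch, (len(devices),) + batch.shape)

# One identical copy per device along the leading axis.
device_batch = jax.device_put_sharded(list(replicated), devices)
```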

ClashLuke commented 1 year ago

Convergence works on v3-32s. Currently training a 16-block (2.5B-parameter) model on one preemptible v3-32. ToDo:

ClashLuke commented 1 year ago

It still diverges. Trying alternatives to gradient descent now *(image)*.

ClashLuke commented 1 year ago

HeavyBall-Adam improves convergence a good bit compared to pure Adam (same speed) *(images)*.

Nesterov-HeavyBall-Adam improves convergence further (same speed) and allows grafting onto it (as the second-moment estimate will never be 0, so the step size won't go to infinity) *(images)*.
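
One plausible way to build such an optimizer with optax is to chain Adam's preconditioning with a (Nesterov-)heavy-ball momentum `trace`; this is a sketch under those assumptions, not the repo's implementation, and the hyperparameters are placeholders:

```python
import optax

def heavyball_adam(lr: float, b1: float = 0.9, b2: float = 0.999,
                   momentum: float = 0.9,
                   nesterov: bool = False) -> optax.GradientTransformation:
    """Adam preconditioning followed by classical (heavy-ball) momentum.

    With nesterov=True this becomes a Nesterov-HeavyBall-Adam variant.
    """
    return optax.chain(
        optax.scale_by_adam(b1=b1, b2=b2),               # adaptive per-parameter step
        optax.trace(decay=momentum, nesterov=nesterov),  # (Nesterov-)heavy-ball momentum
        optax.scale(-lr),                                # descend rather than ascend
    )

optimizer = heavyball_adam(1e-3, nesterov=True)
```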

ClashLuke commented 1 year ago

Reducing batch size and LR decreases final accuracy *(images)*.

ClashLuke commented 1 year ago

Increasing depth helps convergence per step *(image)* while decreasing convergence per wall-clock time *(image)*.

ReZero helps in early training (unknown whether this stays the same during the late stages) *(image)*.
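
For context, a minimal sketch of the ReZero pattern (Bachlechner et al.), y = x + alpha * f(x) with alpha initialized to 0, using a toy MLP sub-layer rather than the repo's block:

```python
import jax
import jax.numpy as jnp

def init_rezero(key, dim: int, hidden: int) -> dict:
    k1, k2 = jax.random.split(key)
    return {"w_in": jax.random.normal(k1, (dim, hidden)) / jnp.sqrt(dim),
            "w_out": jax.random.normal(k2, (hidden, dim)) / jnp.sqrt(hidden),
            "alpha": jnp.zeros(())}  # the learned residual gate starts at 0

def rezero_block(params: dict, x: jnp.ndarray) -> jnp.ndarray:
    """y = x + alpha * f(x): the block starts as the identity, and the
    sub-layer is blended in as alpha is learned."""
    hidden = jnp.maximum(x @ params["w_in"], 0.0) @ params["w_out"]  # toy sub-layer f(x)
    return x + params["alpha"] * hidden
```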

Maximizing the context size (with the same number of tokens) increases batch-to-batch variance, likely because it's seeing so many different datasets *(image)*. Potential solution: train with a short context for a few epochs and "fine-tune" a longer context onto the model, as in BERT and Shortformer.

ClashLuke commented 1 year ago

SandwichLN (green) outperforms no output norm (grey), and HeavyBall (green) outperforms Nesterov-HeavyBall (blue) *(image)*. Non-HeavyBall Adam diverges, likely due to a small second moment.
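
For context, a minimal sketch of the Sandwich-LN pattern (LayerNorm on both the input and the output of the sub-layer, y = x + LN(f(LN(x)))); the repo's block and norm details may differ:

```python
import jax
import jax.numpy as jnp

def layer_norm(x: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps)

def sandwich_block(x: jnp.ndarray, f) -> jnp.ndarray:
    """y = x + LN(f(LN(x))): normalizing the sub-layer's output as well as its
    input bounds the scale of what gets added back to the residual stream."""
    return x + layer_norm(f(layer_norm(x)))
```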

ClashLuke commented 1 year ago

Similarly, with a depth of 16 (above has a depth of 1), ReZero (purple) outperforms no output normalization (blue) but doesn't converge as well as SandwichLN (red) *(image)*.

ClashLuke commented 1 year ago

For some reason, increasing the width of the network still hurts convergence *(image)*.

ClashLuke commented 1 year ago

Losses finally match *(image)*.

ClashLuke commented 1 year ago

The model finally converges, and v3-32 seems marginally better than v3-8 *(image)*. Not trying to match the new v3-8 loss with the old loss.

ClashLuke commented 1 year ago

The main difference is likely the 0-init in the output embedding, which helps improve stability *(image)*. Additionally, the data sampling is improved, so convergence is slower and we'll have fewer spikes (as it can't overfit to a subsection).
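
A minimal sketch of what zero-initializing the output embedding means; the names and sizes here are hypothetical, not the repo's:

```python
import jax.numpy as jnp

vocab, dim = 50_000, 1_024  # hypothetical sizes

# Zero-initialized output projection: all logits start at exactly 0, i.e. the
# model begins with a uniform prediction over the vocabulary, so neither the
# logits nor the gradients flowing back through the head can spike early on.
out_embed = jnp.zeros((dim, vocab), dtype=jnp.float32)

def logits(hidden: jnp.ndarray) -> jnp.ndarray:
    return hidden @ out_embed
```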

ClashLuke commented 1 year ago

The core problem seems to be that it doesn't converge faster with more devices.

ClashLuke commented 1 year ago

The new scan + psum_scatter + all_gather-based gradients are faster and more accurate, yielding the same (or better) loss curve as the non-parallel baseline *(image)*.
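
For illustration, a sketch of the reduce-scatter / all-gather pattern referred to here (a ZeRO-style sharded update inside `pmap`); the axis name, shapes, and plain-SGD update rule are assumptions, not the repo's code. `param_shard` is this device's slice of the parameters, `local_grad` the full-size gradient from this device's batch:

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(param_shard: jnp.ndarray, local_grad: jnp.ndarray, lr: float = 1e-3):
    # Sum the per-device gradients and keep only this device's slice of the
    # result (a reduce-scatter instead of a full all-reduce). Assumes the
    # leading dimension of local_grad is divisible by the device count.
    grad_shard = lax.psum_scatter(local_grad, "data", tiled=True)
    # Apply the update to the local parameter shard only.
    new_shard = param_shard - lr * grad_shard
    # Reassemble the full parameter tensor for the next forward pass.
    full_params = lax.all_gather(new_shard, "data", tiled=True)
    return new_shard, full_params

p_step = jax.pmap(step, axis_name="data")
```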

Now onto fixing the difference between the old (blue) and new (orange/grey) code.

ClashLuke commented 1 year ago

Together with LeakyReLU, we're back at the old performance *(image)*.

Testing Mish now.

ClashLuke commented 1 year ago

It works, but v3-32 still doesn't work better than v3-8 *(image)*.

Possibly later?

ClashLuke commented 1 year ago

LeakyReLU (0.01) > Mish > LeCunTanH, and even the bumps in the loss curves still match *(image)*.

ClashLuke commented 1 year ago

I looked over the differences once again and solved all DeepSource issues except for missing docstrings. Additionally, this branch introduces some pytests to ensure the custom gradients are actually correct this time (they weren't before).

Merging, as the state is stable and significantly better than the current main.

Convergence on v3-32s needs to be addressed in a separate branch.
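
A hypothetical example of the kind of pytest meant here, checking a `custom_vjp` against finite differences with `jax.test_util.check_grads`; `my_op` is a stand-in, not one of the repo's ops:

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

@jax.custom_vjp
def my_op(x):
    return jnp.tanh(x) * x

def my_op_fwd(x):
    return my_op(x), x

def my_op_bwd(x, g):
    # d/dx [x * tanh(x)] = tanh(x) + x * (1 - tanh(x)^2)
    return (g * (jnp.tanh(x) + x * (1.0 - jnp.tanh(x) ** 2)),)

my_op.defvjp(my_op_fwd, my_op_bwd)

def test_custom_vjp_matches_numerical_gradient():
    x = jnp.linspace(-2.0, 2.0, 8)
    # check_grads compares the hand-written reverse-mode gradient against
    # finite differences and raises if they disagree.
    check_grads(my_op, (x,), order=1, modes=("rev",))
```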