Closed: ClashLuke closed this issue 1 year ago
Shampoo improves stability:
Keeping it.
This was probably the most difficult thing to find:
We used different inputs on each model-parallel device in all previous attempts (with the sharded Dataloader).
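For reference, a minimal sketch of the fix, assuming model-parallel shards live on separate local devices; `jax.device_put_replicated` is real JAX, but the setup and shapes here are purely illustrative:

```python
import jax
import numpy as np

# The bug: with the sharded Dataloader, each model-parallel device received a
# *different* slice of the batch, so the shards of one model computed gradients
# for different data.
devices = jax.local_devices()
batch = np.random.randn(8, 128).astype(np.float32)  # (batch, seq) -- illustrative

# Previous (wrong) behaviour: shard the batch across model-parallel devices.
wrong = [batch[i::len(devices)] for i in range(len(devices))]

# Fix: every model-parallel device sees the identical batch.
replicated = jax.device_put_replicated(batch, devices)
```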
Convergence works on v3-32s. Currently training a model with 16 blocks (2.5B parameters) on one preemptible v3-32. ToDo:
It still diverges. Trying alternatives to gradient descent now.
HeavyBall-Adam improves convergence a good bit compared to pure Adam (same speed):
Nesterov-HeavyBall-Adam improves convergence further (same speed) and allows grafting onto it (as the second moment will never be 0, so the step size won't go to infinity)
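A hedged sketch of what Nesterov-HeavyBall-Adam could look like when expressed with optax (the optimizer in this repo may be hand-rolled and differ in detail): Adam's preconditioned update followed by Nesterov heavy-ball momentum.

```python
import jax.numpy as jnp
import optax

learning_rate = 1e-3  # illustrative value
optimizer = optax.chain(
    optax.scale_by_adam(b1=0.9, b2=0.99),   # Adam preconditioning
    optax.trace(decay=0.9, nesterov=True),  # heavy-ball momentum, Nesterov variant
    optax.scale(-learning_rate),            # descend
)

params = {"w": jnp.ones((4, 4))}
grads = {"w": jnp.full((4, 4), 0.1)}
state = optimizer.init(params)
updates, state = optimizer.update(grads, state, params)
params = optax.apply_updates(params, updates)
```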
Reducing batch size and LR decreases final accuracy
Increasing depth improves convergence per step, while decreasing convergence per wall-clock time.
ReZero helps in early training (unknown whether this holds during the late stages)
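For context, a hedged sketch of the ReZero idea: each residual branch is scaled by a learnable scalar alpha initialized to 0, so every block starts as the identity map. Function and parameter names are illustrative, not the repo's.

```python
import jax.numpy as jnp

def init_rezero(block_params):
    # alpha starts at 0.0, so the block is a no-op at initialization.
    return {"alpha": jnp.zeros(()), "block": block_params}

def rezero_block(params, x, block_fn):
    # x + alpha * F(x)
    return x + params["alpha"] * block_fn(params["block"], x)
```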
Maximizing the context size (with the same number of tokens) increases batch-to-batch variance, likely because the model is seeing so many different datasets.
Potential solution: Train short for a few epochs and "fine-tune" a longer context onto the model, as in BERT and Shortformer
SandwichLN (green) outperforms no output norm (grey), and HeavyBall (green) outperforms Nesterov-HeavyBall (blue):
Non-HeavyBall Adam diverges, likely due to a small second moment.
Similarly, with a depth of 16 (above has a depth of 1), ReZero (purple) outperforms no output normalization (blue) but doesn't converge as well as SandwichLN (red):
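For reference, a hedged sketch of the SandwichLN pattern compared above: LayerNorm is applied to both the input and the output of each sublayer before the residual add. Scale/offset parameters are omitted for brevity, and the repo's block layout may differ.

```python
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Plain LayerNorm over the last axis, without learned scale/offset.
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def sandwich_ln_block(x, sublayer):
    # x + LN(sublayer(LN(x)))
    return x + layer_norm(sublayer(layer_norm(x)))
```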
For some reason, increasing the width of the network still hurts convergence.
Losses finally match
Model finally converges and v3-32 seems marginally better than v3-8:
Not trying to match the new v3-8 loss with the old loss.
The main difference is likely the zero-init of the output embedding, which helps stability.
Additionally, the data sampling is improved, so convergence is slower and we'll have fewer spikes (as it can't overfit to a subsection).
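A minimal sketch of the zero-init mentioned above, assuming a plain output projection; shapes and names are illustrative. With an all-zero output embedding, the initial logits are zero (a uniform softmax), which tends to damp early loss spikes.

```python
import jax.numpy as jnp

def init_output_embed(hidden_dim, vocab_size):
    # Zero-init: the model starts by predicting a uniform distribution.
    return jnp.zeros((hidden_dim, vocab_size))

def output_logits(hidden, out_embed):
    return hidden @ out_embed  # all-zero logits at step 0
```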
The core problem seems to be that it doesn't converge faster with more devices.
New scan+psum_scatter+all_gather-based gradients are faster and more accurate, yielding the same (or better) loss curve as the non-parallel baseline:
Now onto fixing the difference between the old (blue) and new (orange/grey) code.
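For context, a minimal sketch of the scan + psum_scatter + all_gather gradient pattern mentioned above, assuming a pmap'd step over an axis named "devices". Names, shapes, and the toy loss are illustrative; the real implementation is more involved, and the tiled collectives require the parameter's leading dimension to be divisible by the device count.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy quadratic loss, purely for illustration.
    return jnp.mean((batch @ params) ** 2)

@partial(jax.pmap, axis_name="devices")
def accumulated_grads(params, microbatches):
    def step(acc, batch):
        # Accumulate gradients over microbatches via lax.scan.
        return acc + jax.grad(loss_fn)(params, batch), None

    grads, _ = jax.lax.scan(step, jnp.zeros_like(params), microbatches)
    # Reduce-scatter: each device keeps only its shard of the summed gradient...
    shard = jax.lax.psum_scatter(grads, "devices", tiled=True)
    # ...then all-gather rebuilds the full, globally summed gradient.
    return jax.lax.all_gather(shard, "devices", tiled=True) / microbatches.shape[0]
```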
Together with LeakyReLU we're back at the old performance:
Testing Mish now.
Works, but v3-32 still doesn't perform better than v3-8.
Possibly later?
LeakyReLU (0.01) > Mish > LeCunTanh, and even the bumps are still matched.
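For reference, a hedged sketch of the activations being compared. The LeCunTanh constants follow LeCun's classic scaled tanh, 1.7159 * tanh(2x / 3); the repo's exact variant may differ.

```python
import jax
import jax.numpy as jnp

def leaky_relu(x):
    return jax.nn.leaky_relu(x, negative_slope=0.01)

def mish(x):
    return x * jnp.tanh(jax.nn.softplus(x))

def lecun_tanh(x):
    return 1.7159 * jnp.tanh(2.0 * x / 3.0)
```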
I looked over the differences once again and solved all DeepSource issues except for missing docstrings. Additionally, this branch introduces some pytests to ensure the custom gradients are actually correct this time (they weren't before).
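A hedged sketch of the kind of pytest this refers to: comparing a custom VJP against numerical gradients with `jax.test_util.check_grads`. `my_custom_op` is a stand-in, not the repo's actual kernel.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

@jax.custom_vjp
def my_custom_op(x):
    return jnp.tanh(x)

def _fwd(x):
    y = jnp.tanh(x)
    return y, y

def _bwd(y, g):
    # d/dx tanh(x) = 1 - tanh(x)^2
    return (g * (1.0 - y ** 2),)

my_custom_op.defvjp(_fwd, _bwd)

def test_custom_gradient_matches_numerical():
    x = jnp.linspace(-2.0, 2.0, 16)
    # Custom VJPs only support reverse mode, so restrict check_grads to "rev".
    check_grads(my_custom_op, (x,), order=1, modes=("rev",))
```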
Merging, as the state is stable and significantly better than the current main.
Convergence with v3-32s needs to be addressed in a separate branch.
Observations:
1) Adam gets unstable at scale, as gradients become smaller and sparser. This causes the second-moment estimate (the beta2 EMA) to decay to 0. However, once one gradient spike happens, the first-moment estimate (the beta1 EMA) shoots up almost immediately, causing a large number to be divided by `eps`, which disturbs the gradients a lot. Current solution: SM3. Looking into MADGRAD and Mirror-MADGRAD to fix this. Untested whether Shampoo is affected as well.
2) Distributed Dataloaders are difficult but possible. Current solution: only one Dataloader and no input stacking.
3) Small-Init + input norm causes significant instability (of the bad kind, see Figure 1). Solution: revert to normal init.
4) FP64 in Jax works out of the box, but there were missing casts in Shampoo. Now running the entire optimizer in FP64.
5) Dying ReLU is a real problem. Many nodes simply die (or get their weights down to 0) and are rarely activated again. Also, LeakyReLU performs worse than Swish/GeLU/Mish/LiSHT. Switching to LeCunTanh, which doesn't have these problems, still allows for inversion, and gives similar or better results on toy datasets. New problem: pruning won't be possible.

Figure 1:
![grafik](https://user-images.githubusercontent.com/39779310/184473761-053aa749-73a9-4016-b070-46b1f5bd7923.png)
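A tiny numeric illustration of observation 1), with made-up values: once the second-moment estimate has decayed to roughly zero, any non-zero first moment is divided by (sqrt(v) + eps), which is essentially `eps`, so the update explodes.

```python
import jax.numpy as jnp

eps = 1e-8
v = 1e-18                        # second moment after many near-zero, sparse gradients
m = 1e-2                         # first moment right after a gradient spike
update = m / (jnp.sqrt(v) + eps)
print(update)                    # ~1e6: a large number divided by (essentially) eps
```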