tysam-code / hlb-CIFAR10

Train to 94% on CIFAR-10 in <6.3 seconds on a single A100. Or ~95.79% in ~110 seconds (or less!)

Push the grand, hotly-anticipated <10 seconds update! (down over 2.3 seconds total!!!) #4

Closed tysam-code closed 1 year ago

tysam-code commented 1 year ago

Here is an overview of the changes. Remember, none of this so far is in JIT (!!!!), so things should be really snappy.

Changes

- Misc extensive hyperparameter tuning (lr <-> final_lr_ratio were especially sensitive to each other).
- Added squeeze-and-excitation layers (very effective, might be faster with PyTorch 2.0). A minimal sketch follows this list.
- Converted the whitening conv from 3x3 -> 2x2. This significantly increased speed and resulted in some accuracy loss, which hyperparameter tuning brought back.
- With the whitening conv at 2x2, we could now set the padding to 0 to avoid an expensive early padding operation. This also made the rest of the network faster, at the cost of some accuracy due to the spatial outputs being slightly smaller.
- The samples for the whitening conv now come from the whole dataset. To be friendlier to smaller GPUs (8 GB or so, I think), we process the whitening operation in chunks over the whole dataset (see the chunked-whitening sketch below).
- We scale the loss before and after summing, since with float16 the summation itself is a regularizing operation, and it was regularizing slightly too strongly for our needs (see the loss-scaling sketch below).
- We unfortunately had to bring another large timesave/accuracy boost off the shelf to make this update fly under 10 seconds (the first being the 3x3 -> 2x2 conv conversion): replacing the CELU(alpha=.3) activation functions with the now-reasonably-standard GELU() activations. They perform extremely well and the PyTorch kernel is very fast for the boost that the activation provides. What's not to like?
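
For reference, here is a minimal PyTorch sketch of a squeeze-and-excitation block (using the same GELU activation mentioned in the last bullet). The class name, reduction factor, and placement are illustrative assumptions, not necessarily how the layer is wired into the network here.

```python
import torch
import torch.nn as nn

class SqueezeExcite(nn.Module):
    """Channel gating: squeeze (global average pool), then excite (tiny MLP)."""
    def __init__(self, channels: int, reduction: int = 4):  # reduction factor is illustrative
        super().__init__()
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                 # (B, C, H, W) -> (B, C, 1, 1)
            nn.Flatten(),                            # -> (B, C)
            nn.Linear(channels, channels // reduction),
            nn.GELU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),                            # per-channel weights in (0, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gate(x).view(x.shape[0], -1, 1, 1)

# Usage: y = SqueezeExcite(64)(torch.randn(8, 64, 16, 16))
```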
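
The chunked whitening pass can look roughly like the sketch below: accumulate patch statistics chunk-by-chunk so peak memory stays small, then build the filters from the covariance. The function name, chunk size, and the exact filter construction (full ZCA here) are assumptions for illustration and may differ from the repo's actual initialization code.

```python
import torch
import torch.nn.functional as F

def whitening_conv_weights(data: torch.Tensor, kernel: int = 2, chunk: int = 5000, eps: float = 1e-5):
    """data: (N, C, H, W) dataset tensor. Returns (C*k*k, C, k, k) whitening filters."""
    N, C = data.shape[:2]
    d = C * kernel * kernel
    sum_x  = torch.zeros(d, dtype=torch.float64, device=data.device)
    sum_xx = torch.zeros(d, d, dtype=torch.float64, device=data.device)
    count = 0
    for i in range(0, N, chunk):                                      # process the whole dataset in chunks
        p = F.unfold(data[i:i + chunk].double(), kernel_size=kernel)  # (n, C*k*k, L)
        p = p.transpose(1, 2).reshape(-1, d)                          # (n*L, d) patches
        sum_x  += p.sum(dim=0)
        sum_xx += p.T @ p
        count  += p.shape[0]
    mean = sum_x / count
    cov = sum_xx / count - torch.outer(mean, mean)                    # patch covariance
    eigvals, eigvecs = torch.linalg.eigh(cov)
    W = (eigvecs / (eigvals + eps).sqrt()) @ eigvecs.T                # ZCA: V diag(1/sqrt(lambda)) V^T
    return W.reshape(d, C, kernel, kernel).float()                    # one 2x2 filter per output channel
```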
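
Mechanically, the scale-before-and-after-summing step is just the following; the loss_scale constant here is illustrative, not the value the repo actually uses.

```python
import torch
import torch.nn.functional as F

def scaled_sum_loss(logits: torch.Tensor, targets: torch.Tensor, loss_scale: float = 32.0) -> torch.Tensor:
    per_sample = F.cross_entropy(logits, targets, reduction='none')  # one (possibly fp16) loss per sample
    return per_sample.mul(loss_scale).sum().div(loss_scale)          # scale up, sum, scale back down
```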

If you'd like to follow our progress on our journey to our goal of training to 94% on CIFAR10 on a single GPU in under 2 seconds within maybe (hopefully) 2 years or so, don't forget to watch the repo! We're in the phase where the updates are starting to get harder to put out, but there's still potential for a few good, "quick" chunks to be optimized away yet.

Further discussion

We've noted that the region of hyperparameter space that performs optimally for a given problem becomes sharper and sharper as we approach optimal performance, similar to the behavior noted in https://arxiv.org/pdf/2004.09468.pdf. Much of this update involved the extremely laborious task of tuning many hyperparameters within the code, which was done manually partially out of laziness, and partially because it's in the best interest of future me to have an instinctive feel for how they interact with each other. Unfortunately, most of the hyperparameter values are not clean powers of 2 anymore, but we eventually did have to break that particular complexity barrier.
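
To make the lr <-> final_lr_ratio coupling from the changelist concrete, here is a toy sketch with an assumed (not the repo's) schedule shape and made-up numbers: the late-training step size is effectively lr * final_lr_ratio, so the two knobs pull on the same quantity and tend to need retuning together.

```python
import torch

def make_schedule(opt, total_steps: int, final_lr_ratio: float):
    # Anneal multiplicatively from 1.0x of the base lr down to final_lr_ratio x of it.
    return torch.optim.lr_scheduler.LambdaLR(
        opt, lambda step: 1.0 + (final_lr_ratio - 1.0) * min(step, total_steps) / total_steps
    )

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.2)                                  # illustrative base lr
sched = make_schedule(opt, total_steps=1000, final_lr_ratio=0.05)   # final lr = 0.2 * 0.05 = 0.01
```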

We performed some preliminary scaling law experiments, and we find that indeed, only increasing the network base width and training epochs yields a good scaling of performance -- twiddling the hyperparameters for these longer runs seems to decrease performance (outside of the EMA). In our runs, we got an average of 95.74% with base width 64->128 and epochs 10->80 (final EMA percentages: 95.72, 95.83, 95.76, 95.99, 95.72, 95.66, 95.53). We're in the regime now where the scaling laws seem to hold very clearly and cleanly, generally speaking, woop woop woop! :O :D <3 <3
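
As a quick arithmetic check, the seven final EMA numbers above do average out to the reported 95.74%:

```python
ema_accs = [95.72, 95.83, 95.76, 95.99, 95.72, 95.66, 95.53]
print(f"{sum(ema_accs) / len(ema_accs):.2f}")  # 95.74
```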

Feel free to open an issue or reach out if you have any questions or anything else of that nature related to this particular work; my email is hire.tysam@gmail.com. Many thanks, and I really appreciate your time and attention!