HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Reduce Compile-Time #26

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

Currently, our models take a while to compile.

Compiling a model of 16 layers on a v3-8 takes almost 15 minutes: [screenshot]

Adding the GPT-2 tokenizer adds another 200 s of runtime: [screenshot]

Using 64 layers (with a character tokeniser) increases the compile time by over 17x, to ~4 h: [screenshot]

In the future, we need our models to compile within a few minutes to ensure that we spend most of our runtime taking training steps. Especially with hyperparameter sweeps, where each run lasts up to 16 hours, a 4-hour compile time is prohibitively long.

This issue is for discussing possible approaches to reduce compile time, and for implementing and benchmarking them.
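For reference, compile time can be measured separately from execution via JAX's ahead-of-time lowering API. A minimal sketch, assuming a JAX version that exposes `Lowered.compile()`; the `step` function and shapes are stand-ins, not Olmax code:

```python
import time

import jax
import jax.numpy as jnp


def step(params, x):
    # Stand-in for the real train step; shapes are illustrative only.
    return jnp.tanh(x @ params)


params = jnp.zeros((256, 256))
x = jnp.ones((8, 256))

lowered = jax.jit(step).lower(params, x)   # trace and lower to XLA HLO
start = time.time()
compiled = lowered.compile()               # the actual XLA compilation
print(f"compile time: {time.time() - start:.1f}s")

out = compiled(params, x)                  # run the already-compiled step
```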

ClashLuke commented 2 years ago

Unfortunately, combining MoE and convolution does not seem possible on TPU without unfolding the inputs, computing all kernels explicitly, or otherwise using enormous amounts of memory and compute. That's why we need to figure out what to do with these convolutional blocks. Replacing them with pointwise operations does not seem sensible either, as pointwise operations use significantly more memory while having fewer FLOPs, resulting in an overall reduced capacity. In previous experiments, we noticed that this capacity difference becomes visible after just a week of training, where models with larger kernels outperformed models with tiny kernels.

Removing these convolutions is not an option, so we will have to figure out how to use them properly.
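To illustrate the unfolding cost: materialising one shifted copy of the sequence per kernel tap (so that per-expert kernels could be applied as plain contractions) multiplies activation memory by the kernel size before any expert computation happens. A toy sketch with made-up shapes, not Olmax code:

```python
import jax.numpy as jnp

batch, seq, features, kernel = 8, 2048, 1024, 5  # made-up shapes

x = jnp.zeros((batch, seq, features))

# "Unfold": materialise one causally shifted copy of the sequence per kernel tap.
padded = jnp.pad(x, ((0, 0), (kernel - 1, 0), (0, 0)))
patches = jnp.stack([padded[:, i:i + seq] for i in range(kernel)], axis=-1)

print(x.shape)        # (8, 2048, 1024)
print(patches.shape)  # (8, 2048, 1024, 5) -> kernel-times the activation memory
```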

One way to reduce compile time would be to run the same computation at every layer. This could be done by using weight sharing or by always referencing the same weight tensor and slicing out the parameters of the current layer.

Unfortunately, previous experiments showed that naively stacking weights results in enormous memory allocations, as gather and scatter operations are required to retrieve each layer's parameters and return their gradients. At the same time, we know that naively sharing all weights results in poor performance that does not justify the reduced memory footprint. Lastly, we also know that the compile time seems to increase roughly quadratically with the depth of the model and already sits at 15 minutes for 16 layers.
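For concreteness, a rough sketch of the "single stacked weight tensor, slice out the current layer" variant, with hypothetical shapes and a toy block rather than the Olmax one:

```python
import jax
import jax.numpy as jnp

depth, features, batch = 16, 1024, 8
stacked_w = jnp.zeros((depth, features, features))  # hypothetical layout


def block(x, w):
    # Toy residual block standing in for the real one.
    return x + jnp.tanh(x @ w)


def forward(x, stacked_w):
    for i in range(depth):
        # Retrieving each layer's slice and routing its gradient back into the
        # stacked tensor is the gather/scatter overhead described above.
        w = jax.lax.dynamic_index_in_dim(stacked_w, i, axis=0, keepdims=False)
        x = block(x, w)
    return x


y = jax.jit(forward)(jnp.ones((batch, features)), stacked_w)
```

Note that with a plain Python loop the computation is still fully unrolled, so any compile-time win only appears once the loop body is wrapped in `lax.scan`, as proposed below.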

That's why I propose using N (=16) blocks with their own layers within a lax.scan, which would allow us to cap the compilation time at 15 minutes. However, I don't know whether the best way of integrating scan is to share weights between different invocations of these 16 blocks, or whether we should instead stack weights locally and hope that this alone reduces the memory overhead. Another option would be to combine these approaches, handling some parameters one way (for example, stacking all normalisation parameters) while handling others differently.
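A minimal sketch of the stacked-weights-plus-scan variant, assuming parameters can be stored with a leading layer axis; the toy `block` below is a stand-in for the real blocks:

```python
import jax
import jax.numpy as jnp

depth, features, batch = 16, 1024, 8

# Hypothetical layout: all per-layer parameters stacked along a leading axis.
stacked_w = jnp.zeros((depth, features, features))
stacked_scale = jnp.ones((depth, features))  # e.g. normalisation scales


def block(x, layer_params):
    # Toy residual block; the real blocks would go here.
    w, scale = layer_params
    return x + jnp.tanh(x @ w) * scale, None  # (new carry, no per-step output)


def forward(x, stacked_params):
    # scan traces `block` once and iterates it over the leading layer axis,
    # so compile time no longer grows with depth.
    x, _ = jax.lax.scan(block, x, stacked_params)
    return x


y = jax.jit(forward)(jnp.ones((batch, features)), (stacked_w, stacked_scale))
```

Weight sharing across scan iterations would correspond to closing over the shared parameters (or passing them through the carry) instead of feeding them in as scanned-over inputs.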

ClashLuke commented 2 years ago

We can scan over the weights, and JAX should hopefully stack and slice them efficiently.

ClashLuke commented 2 years ago

Addressed by #73