HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Initialize deep model from shallow model #43

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

We already support loading pretrained input embeddings. However, output embeddings and layers still have to be retrained from scratch. One way to reuse smaller checkpoints when training larger models (if comparing loss curves doesn't matter) would be to initialise the larger model from the weights of the smaller one by replicating them. As our models always have a fixed width for a given number of devices, loading the checkpoint of a shallower model would be as simple as converting input_embedding-layer1-layer2-output_embedding into input_embedding-layer1-layer2-layer1-layer2-output_embedding. This issue tracks progress on such a scheme, which should achieve faster convergence by effectively skipping the loss of the first thousand steps.
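A minimal sketch of the replication described above, assuming a checkpoint is a flat dict mapping parameter names (`input_embedding`, `layer1`, `layer2`, `output_embedding`) to weight arrays; the function name and layout are illustrative, not taken from the Olmax codebase:

```python
# Hypothetical sketch: deepen a shallow checkpoint by repeating its layer
# stack in order, keeping the input/output embeddings as-is. Names are
# illustrative and do not mirror the actual Olmax checkpoint format.
def deepen_checkpoint(params: dict, repeats: int = 2) -> dict:
    # Collect the layer keys in their original (insertion) order.
    layers = [k for k in params if k.startswith("layer")]
    deep = {"input_embedding": params["input_embedding"]}
    idx = 1
    # Repeat the whole layer stack `repeats` times, renumbering as we go,
    # so layer1-layer2 becomes layer1-layer2-layer3-layer4 with weights
    # W1-W2-W1-W2.
    for _ in range(repeats):
        for key in layers:
            deep[f"layer{idx}"] = params[key]
            idx += 1
    deep["output_embedding"] = params["output_embedding"]
    return deep

shallow = {"input_embedding": "E_in", "layer1": "W1", "layer2": "W2",
           "output_embedding": "E_out"}
deep = deepen_checkpoint(shallow)
# deep now holds layer1=W1, layer2=W2, layer3=W1, layer4=W2
```

Since width is fixed per device count, every replicated layer's shapes already match the deeper model, so no reshaping is needed — only the key renumbering shown here.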

ClashLuke commented 2 years ago

Solved by #51