lucidrains / MEGABYTE-pytorch

Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch

Training Results and Scaling #12

Open MiscellaneousStuff opened 1 year ago

MiscellaneousStuff commented 1 year ago

Hi there.

I’ve run the training code in this repository for 25k of the 100k batches and reached a validation loss of around 1.28 (perplexity of about 3.59). After that point the training loss keeps dropping, but the validation loss either plateaus or slowly starts climbing again. Did you see the same behaviour? (One caveat: I stopped at 25k and restarted training. I reloaded the model and optimiser checkpoints but didn’t preserve the train/val shuffling, so I’m not sure whether that confounds the result.)

I also tried running the training on an H100 with 80GB of VRAM using a batch size of 60 instead of 4, and saw much slower convergence plus an earlier plateau of the validation loss (around 2.5). Do other hyperparameters need to be adjusted to scale training to larger devices? My original runs were on an RTX 3060 Ti with 8GB of VRAM.
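For reference, the shuffling confound could be removed by fixing the split and the resume state explicitly. Something along these lines (a rough sketch with placeholder names and checkpoint keys, not the exact script from this repo) would keep the train/val split and checkpoint state consistent across restarts:

```python
import numpy as np
import torch

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# illustrative byte-level data loading; the split is decided once and never reshuffled
data = np.frombuffer(open("enwik8", "rb").read(), dtype=np.uint8)
split = int(0.9 * len(data))
train_data, val_data = data[:split], data[split:]

def resume(model, optim, path="checkpoint.pt"):
    # restore weights, optimiser state, and the step counter so training
    # (and any LR schedule) continues from where it stopped
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optim.load_state_dict(ckpt["optim"])
    return ckpt["step"]
```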

Thanks in advance.

ChukwumaChukwuma commented 1 year ago

Hi there,

Thanks for your question. It sounds like you are running into a common issue in language model training: an early validation plateau, where the validation loss stops improving after a certain number of steps even though the training loss keeps decreasing.

There are a few possible reasons for this. One is that the model is overfitting the training data, which can happen if the model has too much capacity or the training data is not diverse enough. Another is that the learning rate is too high, which makes the updates bounce around the loss landscape and prevents stable convergence.

In your case, overfitting is a plausible explanation. The pattern you describe, with training loss still falling while validation loss flattens or starts climbing, is the classic signature of the model memorizing the training sequences rather than learning patterns that generalize to the held-out split.

You can try to address the early plateau by doing the following (a short scheduler sketch follows the list):

- Lower the learning rate, or use a warmup-then-decay schedule so that updates shrink as training progresses.
- Add regularization (dropout, weight decay) or reduce model capacity to curb overfitting.
- Use early stopping: keep the checkpoint with the best validation loss instead of training for the full 100k batches.
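On the learning-rate point, a common recipe is linear warmup followed by cosine decay. A minimal PyTorch sketch (the tiny stand-in model, base LR, and step counts below are placeholders, not values from this repo):

```python
import math
import torch

model = torch.nn.Linear(512, 512)          # stand-in for the actual model
optim = torch.optim.Adam(model.parameters(), lr=2e-4)

WARMUP_STEPS = 1_000
TOTAL_STEPS = 100_000

def lr_lambda(step):
    # linear warmup, then cosine decay to zero over the remaining steps
    if step < WARMUP_STEPS:
        return step / max(1, WARMUP_STEPS)
    progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)

# inside the training loop, after each optimizer step:
#   optim.step(); scheduler.step()
```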

If you are still seeing the early plateau after trying these suggestions, you may need a larger dataset: more (and more diverse) data gives the model more to learn from and helps it generalize better to new data.

As for your question about scaling training to larger devices: yes, other hyperparameters usually need to be adjusted together. In particular, when the batch size goes up (here from 4 to 60), the learning rate typically needs to be scaled up as well, commonly in proportion to the batch size or its square root; otherwise each step makes less relative progress and convergence looks slower, which matches what you observed. Pairing the larger batch with warmup and a decoupled-weight-decay optimizer such as AdamW is also common.
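As a rough illustration of the batch-size/learning-rate coupling (the base learning rate here is a placeholder, not the value used in this repo):

```python
import torch

BASE_BATCH, BASE_LR = 4, 2e-4   # hypothetical reference run on the 8GB card
NEW_BATCH = 60                  # the H100 run

# linear scaling rule: grow the LR in proportion to the batch size;
# a gentler alternative is square-root scaling:
#   BASE_LR * (NEW_BATCH / BASE_BATCH) ** 0.5
scaled_lr = BASE_LR * (NEW_BATCH / BASE_BATCH)

model = torch.nn.Linear(512, 512)   # stand-in for the actual model
optimizer = torch.optim.AdamW(model.parameters(), lr=scaled_lr, weight_decay=0.01)
```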

I hope this helps!