axolotl-ai-cloud / axolotl


Shuffling the data across all epochs #735

Open BugReporterZ opened 11 months ago

BugReporterZ commented 11 months ago


πŸ”– Feature description

It is widely observed that fine-tuning for multiple epochs on the same dataset, no matter how the data is shuffled within each epoch, produces stepped train/eval loss curves. This has recently been attributed to memorization, as discussed for example in these blog posts:

In short, it would seem that by the end of each epoch the model may have become too confident, producing discontinuities in the loss graphs as the next epoch begins. Alternatively, it might be that the model is learning undesirable macro-scale patterns from the consecutive epochs it sees during training.

βœ”οΈ Solution

Instead of shuffling the data once and then making the model iterate several times over it, a possible solution would be to repeat the data first and then shuffle it across all epochs. This appears to prevent stepwise loss curves and probably gives a fairer, less biased picture of how the model is learning from the data.

The complete algorithm could be as follows: duplicate the dataset once per planned epoch, shuffle the concatenated copies as a whole, and then train for a single pass over the result (see the sketch below).
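
As a rough illustration only (not an existing Axolotl option), a minimal sketch of this repeat-then-shuffle preprocessing, assuming a Hugging Face `datasets.Dataset` and a made-up helper name, could look like:

```python
from datasets import Dataset, concatenate_datasets

def repeat_and_shuffle(dataset: Dataset, num_epochs: int, seed: int = 42) -> Dataset:
    """Duplicate the dataset once per planned epoch, then shuffle the whole
    concatenation so that samples from different "epochs" interleave freely."""
    repeated = concatenate_datasets([dataset] * num_epochs)
    return repeated.shuffle(seed=seed)

# Training then runs for a single epoch over the returned dataset, which covers
# the same number of samples as `num_epochs` passes over the original data.
```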

Of course, this alternative shuffling method could be made optional.

❓ Alternatives

An alternative to introducing a separate multi-epoch shuffling method would be to allow users to specify a separate evaluation dataset. That way, the user could provide a training dataset with repeated data, while the evaluation dataset (normally split off from the training data in Axolotl) would not be affected by the repetition; a sketch of this idea follows.
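
For illustration only (again, not a current Axolotl feature), one way to approximate this at preprocessing time would be to hold out the eval split before repeating the training split; the function name and default fraction below are assumptions made for the sketch:

```python
from datasets import Dataset, concatenate_datasets

def split_then_repeat(dataset: Dataset, num_epochs: int,
                      eval_fraction: float = 0.05, seed: int = 42):
    """Hold out an eval split *before* repeating, so evaluation never sees
    duplicated samples, then repeat-and-shuffle only the training split."""
    splits = dataset.train_test_split(test_size=eval_fraction, seed=seed)
    train = concatenate_datasets([splits["train"]] * num_epochs).shuffle(seed=seed)
    return train, splits["test"]
```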

πŸ“ Additional Context

Repeating the data at the dataset level and then shuffling it confirms that the train loss curve decreases smoothly, without steps. Below is a screenshot from a dataset that was first duplicated 3 times (what would normally be 3 epochs), then shuffled, and finally trained for 1 epoch in Axolotl.

Unfortunately, since doing it this way means the eval split ends up including repeated data, it wouldn't be meaningful to also show the eval loss curve here.

[image: train loss curve decreasing smoothly, without steps]
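
In terms of the `repeat_and_shuffle` sketch above, this experiment would correspond to something like `repeat_and_shuffle(train_ds, num_epochs=3)` followed by a single-epoch training run (the names here are, of course, just illustrative).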


winglian commented 11 months ago

@BugReporterZ just curious, what did the eval loss look like?

BugReporterZ commented 11 months ago

It looked like this, decreasing at an almost constant rate, which seems unrealistic:

[image: eval loss curve decreasing at an almost constant rate]

enn-nafnlaus commented 10 months ago

I cannot imagine that this would prevent memorization.

Train loss always goes down; the stepwise spike is in eval_loss, not train_loss, so seeing the train loss go down doesn't mean anything.

Anyway, if you want to prevent memorization, we have dropout now :)