ml-explore / mlx-examples

Examples in the MLX framework
MIT License

LoRA: Increased volatility of train loss #583

Open madroidmaq opened 7 months ago

madroidmaq commented 7 months ago

With the latest LoRA training code, the training loss has become noticeably more volatile. Looking into the cause, I suspect it was introduced by commit #528, which reorders the dataset internally before batching.

https://github.com/ml-explore/mlx-examples/blob/e2205beb668abb2334a75f405de094436cd323fe/llms/mlx_lm/tuner/trainer.py#L78-L86

I reverted this commit and retrained, and the loss curve became stable again. Below is what I observed during training. In my case, this fluctuation reduces the accuracy of the final model by about 10%.

[Image: training loss curves]

I'm not sure what the motivation for this change is, or how I can skip the dataset-sorting logic in this case. If the sorting really is needed, one option I can think of is to extend the Dataset definition so that users can plug in their own datasets.

awni commented 7 months ago

The fluctuation in this part will reduce the accuracy of the final model by about 10%.

How did you measure that?

Indeed, I added sorting before batching because it greatly reduces wasted computation on padded sequences, which makes training faster. This might give the training loss slightly higher variance, but I did not expect it to affect the final validation loss. Still, it would be good to make sure we measure that correctly.
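A toy count makes the padding argument concrete. This is a sketch under stated assumptions: each batch is padded to its longest sequence, the sequence lengths are randomly generated rather than taken from any real dataset, and `padding_waste` is a hypothetical helper, not part of mlx_lm.

```python
import random


def padding_waste(lengths, batch_size):
    """Count padding tokens when each batch is padded to its longest sequence."""
    waste = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i : i + batch_size]
        waste += max(batch) * len(batch) - sum(batch)
    return waste


random.seed(0)
# Hypothetical token lengths for 1,000 examples; real datasets will differ.
lengths = [random.randint(20, 500) for _ in range(1000)]

unsorted_waste = padding_waste(lengths, batch_size=4)        # arbitrary order
sorted_waste = padding_waste(sorted(lengths), batch_size=4)  # sorted by length
print(f"padding tokens, unsorted: {unsorted_waste}, sorted: {sorted_waste}")
```

Sorting packs similar lengths into the same batch, which is where the savings come from. The flip side is that each batch is then drawn from a narrow length range, which is one plausible reason the per-step training loss becomes noisier.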

Roughly how large is the dataset you are training on?

madroidmaq commented 7 months ago

I trained on a private dataset of about 10k examples, split 8:1:1 into train/valid/test. In my use case the model produces a stable, structured output (a structured URL instruction), so its accuracy is easy to calculate.

Accuracy before changes: 51.98% (551/1060)
Accuracy after changes: 41.60% (441/1060)
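For reference, a minimal exact-match accuracy computation. This assumes "accuracy" means the fraction of test outputs that exactly match the expected structured string; the thread does not spell out the metric, and `exact_match_accuracy` is a hypothetical helper. The last lines just re-derive the reported percentages from the raw counts.

```python
def exact_match_accuracy(predictions, references):
    """Fraction of model outputs that exactly match the reference string."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    # Ignore leading/trailing whitespace; everything else must match exactly.
    correct = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return correct / len(references)


# Re-deriving the reported numbers from the raw counts on the 1060-example split.
before = 551 / 1060
after = 441 / 1060
print(f"{before:.2%} {after:.2%} {before - after:.2%}")  # 51.98% 41.60% 10.38%
```

The gap of roughly 10.4 percentage points matches the "about 10%" drop described in the original report.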

awni commented 7 months ago

Yikes, that's pretty bad. A couple of ideas:

  1. We can add a no-sort flag.
  2. We could try to make the batch contents more random.
  3. We could revert the sorting for the time being.

I will play around with some options. If you have time to explore a bit that would also be helpful!
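Ideas 1 and 2 can be combined in a small sketch: a flag that turns length sorting on or off, plus shuffling the order of the batches so consecutive steps do not sweep from short to long sequences. `make_batches` and its parameters are hypothetical names for illustration, not the mlx_lm trainer's API, and batch-order shuffling is only one of several ways to add randomness.

```python
import random


def make_batches(lengths, batch_size, sort_by_length=True, seed=None):
    """Split example indices into batches (hypothetical helper, not mlx_lm's API)."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    if sort_by_length:
        # Group examples of similar length so each batch needs little padding.
        indices.sort(key=lambda i: lengths[i])
    else:
        # "No sort" mode: a plain random shuffle of the examples.
        rng.shuffle(indices)
    batches = [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]
    # In both modes, visit the batches in random order so training does not
    # march from the shortest to the longest sequences.
    rng.shuffle(batches)
    return batches


lengths = [5, 50, 7, 48, 6, 49]
print(make_batches(lengths, batch_size=3, seed=0))
print(make_batches(lengths, batch_size=3, sort_by_length=False, seed=0))
```

With sorting on, the same examples always land in the same batch; only the batch order changes. Shuffling the indices before the (stable) sort would additionally vary which equal-length examples get paired.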

madroidmaq commented 7 months ago


I'd be happy to take a stab at this part and test any changes you have in mind.

  1. Adding a sort flag is a quick fix.
  2. Randomizing the dataset is also a good idea. I tried it and saw roughly a 1-point accuracy improvement (52.877% -> 53.774%). However, the loss still fluctuates with the randomized dataset: [Image: training loss with the randomized dataset]

Separately, I'd like to hear what you think of the dataset formatting change in #548, and whether you would accept it (or whether it's simple enough). Once that is merged (although what follows does not necessarily depend on it), I'll try to reproduce the training approach from the paper "Rethinking Optimization and Architecture for Tiny Language Models", dynamically adjusting both the learning rate and the batch size during fine-tuning. My current idea is to handle the dynamic batch size at the dataset level, since that can be decoupled from the rest of the fine-tuning code.
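A minimal sketch of what "dynamic batch size at the dataset level" could look like: the batch iterator asks a schedule for the size of each next batch, so the training loop itself does not change. The linear ramp and all names below are hypothetical placeholders, not the schedule from the paper, and learning-rate adjustment is left out.

```python
def linear_batch_size(step, start=4, end=32, ramp_steps=1000):
    """Hypothetical schedule: grow the batch size linearly from start to end."""
    if step >= ramp_steps:
        return end
    return start + (end - start) * step // ramp_steps


def iterate_dynamic_batches(indices, schedule):
    """Yield consecutive slices of `indices`, each sized by schedule(step)."""
    step, pos = 0, 0
    while pos < len(indices):
        size = max(1, schedule(step))  # guard against a zero-size batch
        yield indices[pos : pos + size]
        pos += size
        step += 1


# Toy schedule: batch k has k + 1 examples.
for batch in iterate_dynamic_batches(list(range(10)), lambda step: step + 1):
    print(batch)  # [0], then [1, 2], then [3, 4, 5], then [6, 7, 8, 9]
```

Keeping the schedule as a plain function makes it easy to swap in a different ramp without touching the dataset or the trainer.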

awni commented 7 months ago

Thanks!! I will take a look at #548 shortly, sorry for the delay!

awni commented 7 months ago

@madroidmaq did you have any time to investigate this? I am hoping to come back to it and figure out a better batching strategy.

madroidmaq commented 7 months ago

@awni Here is an update on my recent experiments. So far, sorting by length has been the least effective strategy. Besides that, I've tried random shuffling and sorting by label (my data includes the corresponding label information).

Here is some data from my tests, the final accuracy and the loss curve.

[Image: Accuracy]
[Image: Loss]

I don't have any other ideas to validate on my side. I can first submit my local change that adds the sort flag as a PR, and then revisit it if we find a better approach later.

awni commented 7 months ago

Thanks for the update. It might be best to simply disable sorting and compile in LoRA for now :(. They give a modest but nice speed improvement, so it's a shame, but this clearly needs more work to get right.

madroidmaq commented 7 months ago

Some bad news: I rebased onto the latest code and retested. The loss curves behave as before, and the sort flag does work as intended. However, when I tested accuracy, it dropped significantly in both runs, and the difference between the two became very small. I'm not sure yet what's causing this; I'll check whether recent commits are responsible.

The dashed line shows the training results before the rebase, and the solid lines show the results with the flag set to True and False.

[Image: loss curves before and after the rebase]

Accuracy also dropped from 54% to around 18%.

[Image: accuracy comparison]

awni commented 7 months ago

That doesn't look so good. What version of MLX are you using and what commit for MLX LM?

awni commented 7 months ago

@madroidmaq Make sure you are using the latest MLX (0.8, or built from source). That's pretty important; otherwise you will hit a bad code path for RMS norm (it won't accumulate in the right precision).
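To illustrate why accumulation precision matters for RMS norm, here is a NumPy sketch (not MLX code, and not a description of how MLX implements it). With float16 inputs, computing `x * x` in float16 can overflow, while upcasting to float32 before the reduction and casting back afterwards gives the expected result. The input values are contrived to trigger the problem; in practice, low-precision accumulation also loses precision, not only range.

```python
import numpy as np


def rms_norm_naive(x, weight, eps=1e-5):
    # Everything stays in x.dtype: x * x is computed (and overflows) in float16.
    return weight * x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def rms_norm_upcast(x, weight, eps=1e-5):
    # Compute the mean of squares in float32, then cast back to the input dtype.
    x32 = x.astype(np.float32)
    y = x32 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    return (weight.astype(np.float32) * y).astype(x.dtype)


x = np.full((1, 8), 300.0, dtype=np.float16)  # 300**2 = 90000 > float16 max (65504)
w = np.ones(8, dtype=np.float16)
print(rms_norm_naive(x, w))   # squares overflow to inf, so every output becomes 0
print(rms_norm_upcast(x, w))  # every output is 1, as expected for a constant input
```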

angeloskath commented 7 months ago

@madroidmaq a couple of weeks ago we changed the default dropout from 0.05 to 0. Could this be the issue?

I am trying to reconcile the lower training and validation loss with the worse accuracy... Setting it back to 0.05 in your training config might be worth checking.
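For context on what that setting controls, here is a NumPy sketch of a LoRA linear layer. It assumes the common formulation in which dropout is applied only to the input of the low-rank branch, while the frozen weight sees the unmodified input. `lora_forward` and its arguments are hypothetical names; this is not the mlx_lm implementation.

```python
import numpy as np


def lora_forward(x, W, A, B, scale=1.0, dropout=0.0, training=True, rng=None):
    """y = x @ W.T + scale * (dropout(x) @ A) @ B

    W: (out, in) frozen weight; A: (in, r) and B: (r, out) trainable factors.
    """
    y = x @ W.T  # frozen path, never affected by dropout
    h = x
    if training and dropout > 0.0:
        if rng is None:
            rng = np.random.default_rng()
        keep = rng.random(x.shape) >= dropout
        h = np.where(keep, x / (1.0 - dropout), 0.0)  # inverted dropout
    return y + scale * (h @ A) @ B


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))
W = rng.standard_normal((3, 4))
A = rng.standard_normal((4, 2))
B = rng.standard_normal((2, 3))
print(lora_forward(x, W, A, B, scale=2.0, dropout=0.05, rng=rng))
```

With `dropout=0.0` (or at evaluation time) the low-rank branch sees the full input on every step, which removes one source of regularization during fine-tuning.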

madroidmaq commented 7 months ago

@awni I've tried both MLX 0.6 and 0.8, and the results are about the same; neither is good. I'm on the latest commit, fbed720.

@angeloskath I'll adjust the dropout to 0.05 and try again.

madroidmaq commented 7 months ago

After setting dropout to a nonzero value, training crashes right after the adapter checkpoint is saved, with the following error:

Iter 200: Train loss 1.258, Learning Rate 8.500e-06, It/sec 2.376, Tokens/sec 2526.571, Trained Tokens 220420, Peak mem 9.175 GB
Iter 200: Val loss 1.337, Val took 9.117s
Iter 200: Saved adapter weights to build/checkpoints/200_adapters.npz.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/madroid/develop/workspace/mlx-examples/llms/mlx_lm/lora2.py", line 227, in <module>
    train(
  File "/Users/madroid/develop/workspace/mlx-examples/llms/mlx_lm/tuner/trainer.py", line 320, in train
    lvalue, toks = step(batch)
                   ^^^^^^^^^^^
IndexError: unordered_map::at: key not found

I'm not familiar with this part of the code. Could you suggest some ideas I can try to fix it?

awni commented 7 months ago

Oof, sorry, that's the compile. You can just remove it for now, as in #608.

angeloskath commented 7 months ago

Hm, that is interesting, because I just trained with dropout enabled. I assume this means you are not on main or on fast_rms on top of mlx-examples.

madroidmaq commented 7 months ago

After applying the changes from #608, it works fine. Adjusting dropout did improve accuracy, but not back to the original level. Here are the results with dropout set to 0.05 and 0.1:

[Image: accuracy with dropout 0.05 and 0.1]

I also rolled the code back to before #528 was committed, and the accuracy was also only around 30%. So something may be wrong with my local code, and I'll look into it further. My code is currently somewhat coupled with the code in this project; I'll spend some time decoupling the two, and will probably follow up with API tweaks in this project so the refactoring goes smoothly.

But this shouldn't affect #611, which adds the sort flag.