ml-explore / mlx-examples

Examples in the MLX framework
MIT License

[Feature Request] Full-Tuning Example #297

Closed: fakerybakery closed this issue 4 days ago

fakerybakery commented 10 months ago

Hi, it would be great to have an example of fine-tuning Phi without LoRA or QLoRA. Thanks!

awni commented 10 months ago

Sure, we can add phi to our LoRA example! @mzbac already did some great work to merge Phi into the generation example. So from there it should be pretty straightforward. I will try and get to it shortly.

fakerybakery commented 10 months ago

Thanks! Is it possible to add an example of tuning Phi without LoRA (i.e. full fine-tuning)? Thanks!

N8python commented 10 months ago

I would love this. I tried unfreezing the model, but that just leads to NaN loss.

mzbac commented 10 months ago

Full fine-tuning can be tricky because it can easily lead to catastrophic forgetting. Even applying LoRA to all layers of phi-2 can cause some issues. I'm not sure whether the MLX team will provide an easy example for that, but it would be great to get some insight from Awni.

yousifa commented 10 months ago

> Sure, we can add phi to our LoRA example! @mzbac already did some great work to merge Phi into the generation example. So from there it should be pretty straightforward. I will try and get to it shortly.

Added support for it in https://github.com/ml-explore/mlx-examples/pull/302

awni commented 10 months ago

> I would love this. I tried unfreezing the model, but that just leads to NaN loss.

@mzbac makes good points. LoRA fine-tuning is much more stable for these large models, especially on small datasets.

Nevertheless, the NaN is concerning... I suspect it's because you are fine-tuning in fp16, which won't work well (the range can easily overflow). You might have better luck using fp32, at least to start, but it will be expensive in memory and compute.

N8python commented 9 months ago

I have the memory and the compute. How would I cast everything to fp32? Just .astype?

awni commented 9 months ago

You can do something like:

module.update(tree_map(lambda p: p.astype(mx.float32), module.parameters()))
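
For completeness, a more self-contained version of that cast (a sketch: `module` is whatever already-loaded `mlx.nn.Module` you are fine-tuning, and `tree_map` comes from `mlx.utils`):

import mlx.core as mx
from mlx.utils import tree_map

# `module` is assumed to be an already-loaded mlx.nn.Module (e.g. the Phi model).
# Cast every parameter to float32 and write the result back into the module.
module.update(tree_map(lambda p: p.astype(mx.float32), module.parameters()))
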
N8python commented 9 months ago

Wow thanks!

N8python commented 9 months ago

Somehow 64GB wasn't enough for a 7B float32 finetune... Adam eats up memory LOL.

N8python commented 9 months ago

Confirmed, SGD works.
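
For reference, a minimal sketch of that swap with mlx.optimizers (the model and training loop are assumed to already exist; Adam keeps two extra moment arrays per parameter, which is where the extra memory goes):

import mlx.optimizers as optim

# Adam: roughly two extra arrays per parameter (first and second moments)
# on top of the weights and gradients.
# optimizer = optim.Adam(learning_rate=1e-5)

# Plain SGD: little to no extra per-parameter state, so far lower memory.
optimizer = optim.SGD(learning_rate=1e-5)

# Inside the training loop, the update call is the same either way:
# optimizer.update(model, grads)
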

N8python commented 9 months ago

But my god, float32 is brutal. 1/10th the speed of float16...

awni commented 9 months ago

It shouldn't be 1/10th... that probably means it's swapping :(. Unfortunately, fine-tuning in 32-bit precision is very memory hungry; it's uncommon to use 32-bit even for pre-training with such large models.

You can try:

  1. Reduce the batch size
  2. Reduce the sequence length

N8python commented 9 months ago

How come mlx fails in 16-bit if most big models are pretrained that way? Is it because it doesn't use bfloat16?

awni commented 9 months ago

People train in float16 and bfloat16, but both cases (typically) require modifications to actually make full training work. bfloat16 is easier than float16, but it still often won't work if you naively swap fp32 for bf16 with plain SGD.
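
To give a flavor of the kind of modification that usually means, here is a rough loss-scaling sketch (not code from this repo; loss_fn, model, optimizer, and the batch variables are assumed to exist, and the scale is static rather than dynamic):

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map

SCALE = 1024.0  # static loss scale; real mixed-precision setups usually adjust it dynamically

def scaled_loss(model, inputs, targets):
    # Compute the loss in reduced precision, then scale it up before differentiation
    # so that small gradients don't underflow to zero in fp16.
    return loss_fn(model, inputs, targets) * SCALE

loss_and_grad = nn.value_and_grad(model, scaled_loss)

def train_step(model, optimizer, inputs, targets):
    loss, grads = loss_and_grad(model, inputs, targets)
    # Undo the scaling before the optimizer sees the gradients.
    grads = tree_map(lambda g: g / SCALE, grads)
    # Skip the step entirely if any gradient overflowed to inf/NaN.
    flat = [g for _, g in tree_flatten(grads)]
    if any(mx.any(mx.isnan(g)).item() or mx.any(mx.isinf(g)).item() for g in flat):
        return loss / SCALE
    optimizer.update(model, grads)
    return loss / SCALE

Full mixed-precision recipes also typically keep an fp32 master copy of the weights, which this sketch leaves out.
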

N8python commented 9 months ago

Got it. Thank you for the info!

tcapelle commented 9 months ago

Even training in FP16 needs a ton of memory: a 7B-parameter model needs 14GB in FP16 just to load the weights. The gradients are just as large, so that's another 14GB, and depending on the optimizer you may need to store extra state. For instance, Adam stores 2 extra values per parameter, so another 28GB... You also need to account for the overhead of the intermediate layer outputs (activations)...
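
A quick back-of-the-envelope version of that accounting (rough numbers only; activations and framework overhead are ignored):

# Approximate memory for full fine-tuning a 7B-parameter model.
params = 7e9
bytes_fp16 = 2

weights = params * bytes_fp16          # ~14 GB of fp16 weights
grads = params * bytes_fp16            # ~14 GB of fp16 gradients
adam_state = 2 * params * bytes_fp16   # ~28 GB for Adam's two moment arrays

print(f"fp16 + Adam: ~{(weights + grads + adam_state) / 1e9:.0f} GB before activations")  # ~56 GB
print(f"fp16 + SGD:  ~{(weights + grads) / 1e9:.0f} GB before activations")               # ~28 GB
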

dsedov commented 9 months ago

How exactly do you switch the whole training / fine-tuning pipeline to FP16 or BF16? I can't find any argument that the optimizer takes. I tried converting all the weights to float16 after the model is created, but that didn't work.

awni commented 9 months ago

> I tried converting all the weights to float16 after the model is created, but that didn't work.

What exactly do you mean by "didn't work"? In general that should work, so there is either a bug in your conversion or a bug in the code you were using to train. In some cases we accidentally upcast (which is a bug), so we have to make sure we aren't doing that.
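
For reference, one way to do that conversion (a sketch, assuming `model` is an already-loaded `mlx.nn.Module`; `Module.apply` maps over the parameters and writes the result back, so the mapping function must return the new array):

import mlx.core as mx

# Cast every parameter of the model to float16 in place.
model.apply(lambda p: p.astype(mx.float16))
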

dsedov commented 9 months ago

I'm trying to train an LLM using this model: https://github.com/ml-explore/mlx-examples/blob/main/llms/llama/llama.py

I've duplicated the code and added this to convert the weights:

def convert_to_f16(self):
    self.apply(self.weigths_to_f16)

def weigths_to_f16(self, m):
    m = m.astype(dtype=mx.float16)

Even with an absolutely tiny learning rate (1e-6) I get NaN in the loss.

danilopeixoto commented 9 months ago

I am currently fine-tuning a quantized version of the Mixtral 8x7B model using a MacBook Pro M3 Max with 64GB of memory. However, I've encountered an issue where I'm seeing nan loss values emerging after a number of training iterations.

Command:

# max_seq_length=8192

python -m mlx_lm.lora --train --model models/mixtral-8x7b-v0.1-8bit-64g/ --data datasets/chat-instruct/ --steps-per-report 1 --steps-per-eval 15 --save-every 15 --iters 500 --lora-layers 16 --batch-size 2

Output:

Loading pretrained model
Total parameters 13231.789M
Trainable parameters 2.229M
Loading datasets
Training
Starting training..., iters: 500
Iter 1: Train loss 1.198, It/sec 0.008, Tokens/sec 9.854
Iter 1: Val loss 1.133, Val took 200.444s
Iter 2: Train loss 1.211, It/sec 0.012, Tokens/sec 15.892
Iter 3: Train loss 1.914, It/sec 0.014, Tokens/sec 4.979
Iter 4: Train loss 1.188, It/sec 0.023, Tokens/sec 19.030
Iter 5: Train loss 0.985, It/sec 0.013, Tokens/sec 9.173
Iter 6: Train loss 0.542, It/sec 0.007, Tokens/sec 10.756
Iter 7: Train loss 1.434, It/sec 0.011, Tokens/sec 4.589
Iter 8: Train loss nan, It/sec 0.035, Tokens/sec 50.523
Iter 9: Train loss nan, It/sec 0.076, Tokens/sec 42.769
Iter 10: Train loss nan, It/sec 0.023, Tokens/sec 65.747

I did cast the model to the bfloat16 data type, but this did not resolve the issue.

As @awni pointed out, training with 16-bit precision is prone to numerical stability issues and requires additional implementation.

Here's an article on mixed precision training that might provide additional insights: https://medium.com/@jbensnyder/solving-the-limits-of-mixed-precision-training-231019128b4b

My current solution involved casting float16 parameters to float32, reducing the batch_size to 1, and setting max_seq_length to 2048. Unfortunately, max_seq_length had to be reduced due to memory constraints.

The average iterations/second for the dataset is 0.045.

...
Iteration 106: Train loss 0.627, Iterations/second 0.010, Tokens/second 4.072
Iteration 107: Train loss 1.750, Iterations/second 0.063, Tokens/second 4.032
Iteration 108: Train loss 0.404, Iterations/second 0.076, Tokens/second 25.085
Iteration 109: Train loss 0.883, Iterations/second 0.036, Tokens/second 22.458
Iteration 110: Train loss 0.577, Iterations/second 0.022, Tokens/second 8.546
Iteration 111: Train loss 1.534, Iterations/second 0.053, Tokens/second 2.759
Iteration 112: Train loss 0.422, Iterations/second 0.023, Tokens/second 22.695
Iteration 113: Train loss 0.585, Iterations/second 0.015, Tokens/second 6.348
Iteration 114: Train loss 0.773, Iterations/second 0.013, Tokens/second 19.616
Iteration 115: Train loss 1.505, Iterations/second 0.014, Tokens/second 2.462
Iteration 116: Train loss 0.637, Iterations/second 0.045, Tokens/second 22.171
Iteration 117: Train loss 0.976, Iterations/second 0.019, Tokens/second 10.869
Iteration 118: Train loss 0.809, Iterations/second 0.027, Tokens/second 5.808
Iteration 119: Train loss 0.728, Iterations/second 0.042, Tokens/second 17.108
Iteration 120: Train loss 1.049, Iterations/second 0.090, Tokens/second 12.404
awni commented 9 months ago

@danilopeixoto does this command:

python -m mlx_lm.lora --train --model models/mixtral-8x7b-v0.1-8bit-64g/ --data datasets/chat-instruct/ --steps-per-report 1 --steps-per-eval 15 --save-every 15 --iters 500 --lora-layers 16 --batch-size 2

still produce NaN for you? That is a model you quantized yourself, right? Just wondering how I can reproduce the issue...

I've had success fine-tuning quantized Mixtral models in the past without NaNs, so there must be some instability that only shows up in your setup.

Maybe it is related to the dataset? You are tuning on a custom dataset, right?

It would be helpful if you could try decoupling your changes to see which one resolved the NaN.

danilopeixoto commented 8 months ago

@awni This is amazing!

Were you using float16 or bfloat16 data type for fine-tuning?

I only managed to fine-tune successfully and get rid of NaN losses by casting float16 values to float32 with mlx==v0.0.10. The changes to sequence length and batch size were necessary to fit the float32 values in memory.

In addition, the experiment was using lambda m: isinstance(m, nn.Linear) as linear_class_predicate.

I will try to evaluate the experiment on the latest MLX version soon.

awni commented 8 months ago

> In addition, the experiment was using lambda m: isinstance(m, nn.Linear) as linear_class_predicate.

That's another good tidbit. I never tried with quantized gates. It may not work.

danilopeixoto commented 8 months ago

@awni I recently had the chance to try out the latest version of MLX (v0.2.0), and indeed, the experiment with the pure linear predicate (lambda m: isinstance(m, nn.Linear)) was resulting in NaN losses during 16-bit (float16) fine-tuning of Mixtral.

https://github.com/ml-explore/mlx-examples/blob/f71e965d57c01a2b4ac3d43d72f75482dc80fe86/llms/mlx_lm/utils.py#L29

awni commented 8 months ago

I see... we recently fixed a bug in our quantized kernels (https://github.com/ml-explore/mlx/pull/677) which may be related to this, so maybe it will work in 0.3.0. Just to be sure: it works if you don't quantize the gate, right?

danilopeixoto commented 8 months ago

Yes. The fine-tuning status for the experiment in discussion:

| Linear class predicate | float16 | float32 |
| --- | --- | --- |
| lambda m: isinstance(m, nn.Linear) | NaN losses | Successful |
| lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8 | Successful | - |

danilopeixoto commented 8 months ago

I believe we can fine-tune the same experiment with a purely linear predicate using the bfloat16 data type in MLX version 0.3.0. However, I haven't had the opportunity to test it with the float16 data type yet.

bfloat16 has also proven effective for fine-tuning a resized lm_head (with QLoRA) and Embedding modules to incorporate additional special tokens, without producing NaN losses.

Initial results

Prompt:

<|start|>system
You are MLX Assist, an AI assistant.
<|end|>
<|start|>user
Hello
<|end|>
<|start|>assistant

Completion:

Hi there! I'm MLX Assist, an AI assistant. I'm here to help with any questions or tasks you may have.
<|end|>
N8python commented 8 months ago

Can confirm the effectiveness of float32 end-to-end tuning on tinyllama.

ziozzang commented 7 months ago

https://github.com/ml-explore/mlx-examples/pull/645

I've opened full fine-tune code. :)

I tested it on my M2 Mac Studio (192GB) with phi-2 (2.8B), and it really works well.

awni commented 4 days ago

This is enabled in MLX LM now. From the docs:

> To fine-tune the full model weights, add the --fine-tune-type full flag. Currently supported fine-tuning types are lora (default), dora, and full.
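
For example, a full fine-tune with the MLX LM LoRA entry point looks something like this (paths are placeholders; other flags are the same as in the commands above):

python -m mlx_lm.lora \
  --train \
  --model <path-or-hf-repo> \
  --data <path-to-dataset> \
  --fine-tune-type full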