Sure, we can add phi to our LoRA example! @mzbac already did some great work to merge Phi into the generation example. So from there it should be pretty straightforward. I will try and get to it shortly.
Thanks! Is it possible to add an example of tuning Phi without LoRA (i.e. full fine-tuning)?
I would love this. I tried unfreezing the model but that just leads to NaN loss.
Full fine-tuning can be tricky because it can easily lead to catastrophic forgetting. Even full-layer LoRA for Phi-2 could cause some issues. I'm not sure whether the MLX team will provide an easy example for that, but it would be great to get some insight from Awni.
Added support for it in https://github.com/ml-explore/mlx-examples/pull/302
> I would love this. I tried unfreezing the model but that just leads to NaN loss.

@mzbac makes good points. LoRA fine-tuning is much more stable for these large models, especially on small datasets.
Nevertheless, the NaN is concerning... I suspect it's because you are fine-tuning in fp16 which won't work well (the range can easily overflow). You might have better luck using fp32 at least to start, but it will be expensive memory / compute wise.
I have the memory and the compute. How would I cast everything to fp32? Just .astype?
You can do something like:

import mlx.core as mx
from mlx.utils import tree_map

# Cast every parameter to float32 and write the result back into the module.
module.update(tree_map(lambda p: p.astype(mx.float32), module.parameters()))
Wow thanks!
Somehow 64GB wasn't enough for a 7B float32 finetune... Adam eats up memory LOL.
Confirmed, SGD works.
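(For anyone following along: Adam keeps two extra moment buffers per parameter, while plain SGD without momentum keeps no extra state, so the swap alone saves a lot of memory. A sketch of the swap with MLX's optimizers; the learning rate is illustrative:)

import mlx.optimizers as optim

# optimizer = optim.Adam(learning_rate=1e-5)  # stores m and v buffers per parameter
optimizer = optim.SGD(learning_rate=1e-5)     # no extra state without momentum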
But my god, float32 is brutal. 1/10th the speed of float16...
It shouldn't be 1/10th... that probably means it's swapping. Unfortunately, fine-tuning in 32-bit precision is very memory hungry; it's uncommon to use 32-bit even for pre-training with such large models.
You can try:
How come MLX fails in 16-bit if most big models are pretrained that way? Is it because it doesn't use bfloat16?
People do float16 and bfloat16, but both cases (typically) require modifications to actually make full training work. bfloat16 is easier than float16, but it still often won't work with naive SGD simply replacing fp32 with bf16.
Got it. Thank you for the info!
Even just to train in FP16, you need a ton of memory; the 7B-parameter model needs 14GB of memory in FP16 just to load the weights. The gradients are just as large, so another 14GB, and depending on the optimizer you may need to store extra state. For instance, Adam stores 2 extra values per parameter, so another 28GB... You also need to account for the overhead of computing the intermediate layer outputs...
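A quick back-of-the-envelope version of that arithmetic (illustrative only, ignoring activations and runtime overhead):

params = 7e9          # 7B parameters
bytes_fp16 = 2        # bytes per value in FP16

weights_gb = params * bytes_fp16 / 1e9    # ~14 GB just to hold the weights
grads_gb = weights_gb                     # ~14 GB for the gradients
adam_gb = 2 * weights_gb                  # ~28 GB for Adam's two moment buffers

print(weights_gb + grads_gb + adam_gb)    # ~56 GB before activations / overhead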
How exactly do you switch the whole training / fine-tuning pipeline to FP16 or BF16? I can't find any arguments that the optimizer takes. I tried converting all the weights after the model is created to float16, but that didn't work.
> I tried converting all the weights after the model is created to float16, but that didn't work.
What exactly do you mean by "didn't work"? In general that should work so there is either a bug in your conversion or a bug in the code you were using to train. In some cases we accidentally upcast (which is a bug) so we have to make sure we aren't doing it.
I'm trying to train an LLM using this model: https://github.com/ml-explore/mlx-examples/blob/main/llms/llama/llama.py

I've duplicated the code and added this to convert the weights:

def convert_to_f16(self):
    self.apply(self.weigths_to_f16)

def weigths_to_f16(self, m):
    m = m.astype(dtype=mx.float16)
Even with an absolutely tiny learning rate (1e-6) I get nan in the loss.
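One thing worth noting about the snippet above: the mapping function only rebinds its local variable m and never returns the cast array, so the weights are likely never actually converted. A cast along the lines of the earlier tree_map suggestion (model here stands for whatever module you constructed) should at least perform the conversion, though as discussed elsewhere in this thread float16 training can still produce NaNs for other reasons:

import mlx.core as mx
from mlx.utils import tree_map

# Cast every parameter to float16 and write the result back into the model.
model.update(tree_map(lambda p: p.astype(mx.float16), model.parameters()))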
I am currently fine-tuning a quantized version of the Mixtral 8x7B model using a MacBook Pro M3 Max with 64GB of memory. However, I've encountered an issue where I'm seeing nan loss values emerging after a number of training iterations.
Command:
# max_seq_length=8192
python -m mlx_lm.lora --train --model models/mixtral-8x7b-v0.1-8bit-64g/ --data datasets/chat-instruct/ --steps-per-report 1 --steps-per-eval 15 --save-every 15 --iters 500 --lora-layers 16 --batch-size 2
Output:
Loading pretrained model
Total parameters 13231.789M
Trainable parameters 2.229M
Loading datasets
Training
Starting training..., iters: 500
Iter 1: Train loss 1.198, It/sec 0.008, Tokens/sec 9.854
Iter 1: Val loss 1.133, Val took 200.444s
Iter 2: Train loss 1.211, It/sec 0.012, Tokens/sec 15.892
Iter 3: Train loss 1.914, It/sec 0.014, Tokens/sec 4.979
Iter 4: Train loss 1.188, It/sec 0.023, Tokens/sec 19.030
Iter 5: Train loss 0.985, It/sec 0.013, Tokens/sec 9.173
Iter 6: Train loss 0.542, It/sec 0.007, Tokens/sec 10.756
Iter 7: Train loss 1.434, It/sec 0.011, Tokens/sec 4.589
Iter 8: Train loss nan, It/sec 0.035, Tokens/sec 50.523
Iter 9: Train loss nan, It/sec 0.076, Tokens/sec 42.769
Iter 10: Train loss nan, It/sec 0.023, Tokens/sec 65.747
I did cast the model to the bfloat16 data type, but this did not resolve the issue.
As @awni pointed out, training with 16-bit precision is prone to numerical stability issues and requires additional implementation.
Here's an article on mixed precision training that might provide additional insights: https://medium.com/@jbensnyder/solving-the-limits-of-mixed-precision-training-231019128b4b
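For what it's worth, the "additional implementation" usually amounts to some form of loss scaling: multiply the loss by a large constant so small float16 gradients don't underflow to zero, divide the gradients back down, and skip any step whose gradients overflowed. A rough sketch of static loss scaling on top of MLX (the helper names, scale value, and loss function are illustrative, not something from this thread):

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map

LOSS_SCALE = 1024.0  # illustrative static scale factor

def make_scaled_step(model, optimizer, loss_fn):
    # Scale the loss so small float16 gradients don't underflow to zero.
    def scaled_loss(model, *batch):
        return loss_fn(model, *batch) * LOSS_SCALE

    loss_and_grad = nn.value_and_grad(model, scaled_loss)

    def step(*batch):
        loss, grads = loss_and_grad(model, *batch)
        # Undo the scaling so the optimizer sees the true gradient magnitude.
        grads = tree_map(lambda g: g / LOSS_SCALE, grads)
        # Skip the update entirely if any gradient overflowed to inf/nan.
        overflowed = any(
            bool(mx.any(mx.isnan(g))) or bool(mx.any(mx.isinf(g)))
            for _, g in tree_flatten(grads)
        )
        if not overflowed:
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)
        return loss / LOSS_SCALE

    return step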
My current solution involved casting float16 parameters to float32, reducing the batch_size to 1, and setting max_seq_length to 2048. Unfortunately, the max_seq_length had to be reduced due to memory constraints.
The average iterations/second for the dataset is 0.045.
...
Iteration 106: Train loss 0.627, Iterations/second 0.010, Tokens/second 4.072
Iteration 107: Train loss 1.750, Iterations/second 0.063, Tokens/second 4.032
Iteration 108: Train loss 0.404, Iterations/second 0.076, Tokens/second 25.085
Iteration 109: Train loss 0.883, Iterations/second 0.036, Tokens/second 22.458
Iteration 110: Train loss 0.577, Iterations/second 0.022, Tokens/second 8.546
Iteration 111: Train loss 1.534, Iterations/second 0.053, Tokens/second 2.759
Iteration 112: Train loss 0.422, Iterations/second 0.023, Tokens/second 22.695
Iteration 113: Train loss 0.585, Iterations/second 0.015, Tokens/second 6.348
Iteration 114: Train loss 0.773, Iterations/second 0.013, Tokens/second 19.616
Iteration 115: Train loss 1.505, Iterations/second 0.014, Tokens/second 2.462
Iteration 116: Train loss 0.637, Iterations/second 0.045, Tokens/second 22.171
Iteration 117: Train loss 0.976, Iterations/second 0.019, Tokens/second 10.869
Iteration 118: Train loss 0.809, Iterations/second 0.027, Tokens/second 5.808
Iteration 119: Train loss 0.728, Iterations/second 0.042, Tokens/second 17.108
Iteration 120: Train loss 1.049, Iterations/second 0.090, Tokens/second 12.404
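For reference, the float16 to float32 cast described above can be done selectively, so only floating-point parameters get promoted and any quantized integer weights are left alone. A minimal sketch (promote_fp16 and model are illustrative names):

import mlx.core as mx
from mlx.utils import tree_map

def promote_fp16(model):
    # Promote only float16 parameters to float32; leave everything else
    # (e.g. quantized integer weights) untouched.
    cast = lambda p: p.astype(mx.float32) if p.dtype == mx.float16 else p
    model.update(tree_map(cast, model.parameters()))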
@danilopeixoto does this command:
python -m mlx_lm.lora --train --model models/mixtral-8x7b-v0.1-8bit-64g/ --data datasets/chat-instruct/ --steps-per-report 1 --steps-per-eval 15 --save-every 15 --iters 500 --lora-layers 16 --batch-size 2
Still produce NaN for you? That is a model you quantized yourself, right? Just wondering how I can reproduce that issue.
I've had success fine-tuning quantized Mixtral models in the past without NaNs, so there must be some instability that is only showing up in your setup.
Maybe it is related to the dataset? You are tuning on a custom dataset, right?
Would be helpful if you could try decoupling your changes to see which ones resolved the NaN:
@awni This is amazing!
Were you using the float16 or bfloat16 data type for fine-tuning?
I only managed to fine-tune successfully and get rid of nan losses by casting float16 values to float32 with mlx==v0.0.10. The changes in sequence length and batch size were necessary to accommodate float32 values in memory.

In addition, the experiment was using lambda m: isinstance(m, nn.Linear) as linear_class_predicate.
I will try to evaluate the experiment on the latest MLX version soon.
> In addition, the experiment was using lambda m: isinstance(m, nn.Linear) as linear_class_predicate.
That's another good tidbit. I never tried with quantized gates. It may not work.
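For reference, the two predicates under discussion differ only in whether the MoE gate/router layers get quantized: the m.weight.shape[0] != 8 check skips any linear layer whose output dimension equals the number of experts (8 for Mixtral 8x7B). The variable names below are just labels for the two variants:

import mlx.nn as nn

# Quantize every nn.Linear, including the MoE gate/router.
all_linear = lambda m: isinstance(m, nn.Linear)

# Skip the gate/router: its output dimension is the number of experts (8),
# so it stays in full precision while all other linear layers are quantized.
skip_gate = lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8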
@awni I recently had the chance to try out the latest version of MLX (v0.2.0), and indeed, the experiment with the pure linear predicate (lambda m: isinstance(m, nn.Linear)) was resulting in NaN losses during 16-bit (float16) fine-tuning of Mixtral.
I see... we recently fixed a bug in our quantized kernels (https://github.com/ml-explore/mlx/pull/677) which may be related to this, so maybe it will work in 0.3.0. Just to be sure: it works if you don't quantize the gate, right?
Yes. The fine-tuning status for the experiment in discussion:
| Linear class predicate | float16 | float32 |
| --- | --- | --- |
| lambda m: isinstance(m, nn.Linear) | NaN losses | Successful |
| lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8 | Successful | - |
I believe we can fine-tune the same experiment with a purely linear predicate using the bfloat16 data type in MLX version 0.3.0. However, I haven't had the opportunity to test it with the float16 data type yet.

The bfloat16 data type has also proven to be effective in fine-tuning a resized lm_head (with QLoRA) and Embeddings modules to incorporate additional special tokens without producing NaN losses.
Initial results
Prompt:
<|start|>system
You are MLX Assist, an AI assistant.
<|end|>
<|start|>user
Hello
<|end|>
<|start|>assistant
Completion:
Hi there! I'm MLX Assist, an AI assistant. I'm here to help with any questions or tasks you may have.
<|end|>
Can confirm the effectiveness of float32 end-to-end tuning on tinyllama.
https://github.com/ml-explore/mlx-examples/pull/645
I opened a PR with the full fine-tune code. :)
I tested it on my M2 Mac Studio (192GB) with Phi-2 (2.8B), and it really works well.
Hi, it would be great to have an example of fine-tuning Phi without LoRA or QLoRA. Thanks!