Blaizzy opened this issue 2 weeks ago
cc: @awni
So there are a couple of things you should change in general in your Llama implementation:
- Use nn.RMSNorm instead of rolling your own.
- Use nn.RoPE instead of rolling your own.
These will both be (much) faster and numerically more stable. The NaNs are being introduced by overflow in your RMSNorm implementation. Typically, whenever you accumulate a lot of numbers, you need to accumulate the result in higher precision (so the mean, in your case). nn.RMSNorm does this implicitly, without the need to cast between mx.float32 and mx.float16.
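To illustrate the overflow described above, here is a minimal sketch in NumPy, used as a stand-in for MLX so it runs anywhere (this is not MLX's or mlx-vlm's code, and the function names are hypothetical). Accumulating the sum of squares in float16 overflows once it passes float16's largest finite value (65504), while doing the reduction in float32 and casting back keeps the result finite. In MLX itself, nn.RMSNorm handles this for you.

```python
import numpy as np


def mean_square_fp16(x):
    """Naive reduction: every partial sum is rounded to float16 (max finite 65504)."""
    acc = np.float16(0.0)
    for v in x:
        acc = np.float16(acc + v * v)  # becomes inf once the running sum passes 65504
    return np.float16(acc / x.size)


def rms_norm(x, weight, eps=1e-5):
    """RMSNorm over the last axis, accumulating in float32 and casting back to x.dtype."""
    x32 = x.astype(np.float32)
    ms = np.mean(x32 * x32, axis=-1, keepdims=True)  # reduction done in float32
    y = x32 / np.sqrt(ms + eps)
    return (y * weight.astype(np.float32)).astype(x.dtype)


x = np.full(4096, 8.0, dtype=np.float16)  # sum of squares: 4096 * 64 = 262144
w = np.ones(4096, dtype=np.float16)
print(mean_square_fp16(x))  # inf
print(rms_norm(x, w)[:3])   # [1. 1. 1.]
```

An inf in an intermediate like this can later turn into NaN (for example via inf - inf or 0 * inf), which matches the symptom in this issue.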
I would double-check that most of your model files use nn.RMSNorm or nn.LayerNorm where possible, and the same for nn.RoPE. Inference especially will be much faster.
Thanks a lot!
Yes, I was using a custom RMSNorm. I changed it to nn.RMSNorm and it's 3.25x faster 🚀.
When it comes to RoPE, I was already using nn.RoPE, since no changes were needed and it's easier to integrate with the cache.
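Why the cache matters for RoPE: when decoding with a KV cache, each new token has to be rotated by its absolute position, which is the number of tokens already in the cache. The NumPy sketch below is a stand-in, not nn.RoPE's actual implementation (it rotates the two halves of the feature dimension, and `rope` is a hypothetical name). It checks that rotating only the last token with `offset=4` matches rotating the whole sequence. In mlx-lm style code this offset is passed when calling the RoPE layer, which is what makes nn.RoPE easy to pair with a cache; check the MLX docs for the exact signature.

```python
import numpy as np


def rope(x, offset=0, base=10000.0):
    """Rotary embedding over the last axis of x with shape (seq_len, dims).

    offset is the absolute position of x[0] in the full sequence.
    """
    seq_len, dims = x.shape
    half = dims // 2
    freqs = base ** (-np.arange(half) / half)           # (half,)
    pos = np.arange(offset, offset + seq_len)[:, None]   # (seq_len, 1)
    theta = pos * freqs                                  # (seq_len, half)
    cos, sin = np.cos(theta), np.sin(theta)
    x1, x2 = x[:, :half], x[:, half:]
    # Rotate each (x1[i], x2[i]) pair by its position-dependent angle.
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


rng = np.random.default_rng(0)
seq = rng.standard_normal((5, 8))
full = rope(seq)                   # whole prompt at once
last = rope(seq[4:], offset=4)     # one decoding step, 4 tokens already cached
print(np.allclose(full[4:], last))  # True
```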
> The NaNs are getting introduced during overflow in your RMSNorm implementation.

How did you check this?
@awni I made the recommended changes, but I can't seem to run training on my machine (M3 Max, 96GB).
It crashes after processing 3 samples, even with a batch size of 1.
```
{'Epoch': 0, 'Step': 0, 'Loss': '1.5820'}
3%|█▍ | 3/100 [00:09<05:03, 3.13s/it, Epoch=0, Step=0, Loss=1.5820]
zsh: segmentation fault python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit
/opt/homebrew/Caskroom/miniconda/base/envs/mlx_code/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
```
Could you please try it on your M2 Ultra and see if the NaN loss persists?
It's running on my M1 Max (32GB) with the dataset modifications you posted above and this command:
```
python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit --dataset 5CD-AI/Viet-ShareGPT-4o-Text-VQA --split Viet_OCR_VQA --steps 100 --learning-rate 5e-6 --lora-rank 16 --lora-alpha 16
```
So far it has processed steps 0 through 11 without a problem (I modified the print to print every step):
```
{'Epoch': 0, 'Step': 0, 'Loss': '1.5796'}
{'Epoch': 0, 'Step': 1, 'Loss': '1.8235'}
{'Epoch': 0, 'Step': 2, 'Loss': '1.9262'}
{'Epoch': 0, 'Step': 3, 'Loss': '1.5627'}
{'Epoch': 0, 'Step': 4, 'Loss': '1.5274'}
{'Epoch': 0, 'Step': 5, 'Loss': '1.7451'}
{'Epoch': 0, 'Step': 6, 'Loss': '1.9609'}
{'Epoch': 0, 'Step': 7, 'Loss': '0.9124'}
{'Epoch': 0, 'Step': 8, 'Loss': '1.7157'}
{'Epoch': 0, 'Step': 9, 'Loss': '1.6776'}
{'Epoch': 0, 'Step': 10, 'Loss': '1.8323'}
{'Epoch': 0, 'Step': 11, 'Loss': '1.4830'}
```
However, you should not be getting a segfault; that isn't good. Which version of MLX are you running? Is anything else different in your setup?
I also noticed that GPU utilization is pretty poor, which is also not good. It should be close to 100% during training, so there is probably a bottleneck somewhere that needs fixing.
Thanks!
Wow, that's really weird.
Here is my setup:
```
prince_canuma@MacBook-Pro-3 ~ % pip list | grep mlx
fastmlx 0.2.1
mlx 0.18.0
mlx-embeddings 0.0.1 /Users/prince_canuma/Documents/Projects/LLMs/mlx-embeddings
mlx-lm 0.19.0
mlx-vlm 0.1.0 /Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm
```
> It should be close to 100% GPU utilization during training so there should be a bottleneck somewhere that needs fixing.
I suspect the dataset loading function. I know it's not the best, but I thought optimizing it could wait for the next release; this one already took long enough.
https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/trainer/trainer.py#L58
Could you try upgrading to the latest MLX (0.18.1), and MLX LM (0.19.1) if it's used here, just to be sure we haven't already fixed something? (I think this PR may be related: https://github.com/ml-explore/mlx/pull/1452)
Also, remind me: what are your machine and OS?
> I suspect the dataset loading function. I know it's not the best but I thought it's an optimization for the next release this one already took long enough.
Data loading is often the issue. And yes, the next release is quite reasonable; I'm just letting you know in case you didn't notice it.
Upgrading to v0.18.1 fixed it! 🚀
> Data loading is often the issue. And yes next release is quite reasonable.. just letting you know in case you didn't notice it.
Thank you! Do you have any tips specific to MLX?
When I started getting the error, I figured it could be the data loading, so I made some rough initial optimizations: using a generator, deleting the batch after processing, and calling the Metal clear cache command.
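The generator idea can be combined with prefetching so the next batch is prepared while the current one trains. This is a minimal sketch of a common technique, not mlx-vlm's trainer code; `batch_generator`, `prefetch`, and the `preprocess` callback are hypothetical names. Note that Python threads only overlap work that releases the GIL (file I/O and many NumPy/PIL operations), so pure-Python preprocessing will not speed up this way.

```python
import queue
import threading


def batch_generator(samples, batch_size, preprocess):
    """Yield preprocessed batches lazily instead of materializing them all up front."""
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        yield [preprocess(s) for s in chunk]


def prefetch(generator, depth=2):
    """Consume `generator` in a background thread, keeping up to `depth` items ready."""
    q = queue.Queue(maxsize=depth)
    sentinel = object()
    error = []

    def worker():
        try:
            for item in generator:
                q.put(item)
        except Exception as exc:  # hand the error to the consuming thread
            error.append(exc)
        finally:
            q.put(sentinel)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is sentinel:
            if error:
                raise error[0]
            return
        yield item


batches = prefetch(batch_generator(list(range(5)), 2, lambda s: s * 10))
print(list(batches))  # [[0, 10], [20, 30], [40]]
```

A bounded queue (`depth`) keeps memory in check: the worker blocks once `depth` batches are waiting, so only a few preprocessed batches exist at a time.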
> Also remind me what's your machine and OS?
MacBook Pro 14-inch
- Chip: M3 Max
- RAM: 96GB
- OS: Sonoma 14.5
> Thank you! Do you have any tips specific to MLX?
First, verify that data loading is in fact the issue. I would do that by using the same batch over and over instead of loading a new one each step, and making sure GPU utilization is close to 100%.
If data loading is the problem, then look into what's actually slow: is it the IO itself or some preprocessing step?
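One way to apply this check: time the batch fetch and the training step separately, then rerun with a single batch reused every step. If throughput jumps (and GPU utilization approaches 100%) with the reused batch, loading or preprocessing is the bottleneck. This is a minimal sketch with hypothetical stand-ins (`slow_load`, `train_step`). In a real MLX loop, the step must force evaluation, for example by calling mx.eval on the loss, because MLX computes lazily and the step time would otherwise look artificially small.

```python
import time


def time_steps(get_batch, train_step, steps=10):
    """Return average seconds per step spent fetching batches and inside train_step."""
    fetch = compute = 0.0
    for i in range(steps):
        t0 = time.perf_counter()
        batch = get_batch(i)
        t1 = time.perf_counter()
        train_step(batch)
        t2 = time.perf_counter()
        fetch += t1 - t0
        compute += t2 - t1
    return fetch / steps, compute / steps


def slow_load(i):
    """Hypothetical loader: pretend I/O + preprocessing takes 20 ms."""
    time.sleep(0.02)
    return [i] * 4


def train_step(batch):
    """Hypothetical cheap training step."""
    return sum(batch)


fixed = slow_load(0)
print(time_steps(slow_load, train_step))          # fetch time dominates
print(time_steps(lambda i: fixed, train_step))    # same batch reused: fetch is ~0
```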
> and using the metal clear cache command.
I wouldn't manually clear the cache unless you have a really good reason. That will typically just slow everything down.
Awesome, thanks!
If you preload the dataset into RAM, it probably isn't the IO.
Do you do the preprocessing in MLX? If not, maybe try doing that so it runs fast on the GPU.
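As a rough illustration of moving preprocessing into array operations, here is a NumPy sketch (a stand-in; the same broadcasted arithmetic should carry over to mlx.core arrays, which is what would let it run on the GPU). The function names are hypothetical and the mean/std values are placeholders, not the statistics any particular HF processor uses. Both functions compute the same rescale-and-normalize; the second replaces per-pixel Python loops with one broadcasted expression.

```python
import numpy as np


def normalize_loop(images, mean, std):
    """Per-pixel Python loop over (batch, H, W, C) uint8 images: simple but slow."""
    out = np.empty(images.shape, dtype=np.float32)
    b, h, w, c = images.shape
    for n in range(b):
        for y in range(h):
            for x in range(w):
                for k in range(c):
                    out[n, y, x, k] = (images[n, y, x, k] / 255.0 - mean[k]) / std[k]
    return out


def normalize_vectorized(images, mean, std):
    """Same computation as one broadcasted expression over the channel axis."""
    x = images.astype(np.float32) / 255.0
    return (x - np.asarray(mean, np.float32)) / np.asarray(std, np.float32)


rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(2, 8, 8, 3), dtype=np.uint8)
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]  # placeholder statistics
print(np.allclose(normalize_loop(images, mean, std),
                  normalize_vectorized(images, mean, std), atol=1e-6))  # True
```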
I preload/prefetch the batch before running it.
Then the bottleneck is probably the HF processor I use here to prepare the inputs.
> I would do that by using the same batch over and over instead of loading it and make sure the GPU utilization is close to 100%.
Could you elaborate on this? I didn't quite get it.
Issue
I keep getting NaN loss when training Llama-3.2-vision.
I tried:
But with no success.
Steps to reproduce:
1. Use the pc/llama3.2-vision branch.
2. Add these two lines (31-32) to lora.py to limit the dataset.
3. Quantize the model (optional).
4. Start training.