Blaizzy / mlx-vlm

MLX-VLM is a package for running Vision LLMs locally on your Mac using MLX.
MIT License
436 stars, 35 forks

NaN loss when training Llama-3.2-vision #84

Open. Blaizzy opened this issue 2 weeks ago.

Blaizzy commented 2 weeks ago

Issue

I keep getting NaN loss when training Llama-3.2-vision.

I tried:

But with no success.

Steps to reproduce:

  1. Install the pc/llama3.2-vision branch:

    pip install -U git+https://github.com/Blaizzy/mlx-vlm.git@pc/llama3.2-vision
  2. Add these two lines (lines 31-32) to lora.py to limit the dataset to the first 20% of the split:

    dataset = load_dataset(args.dataset, split=args.split+"[:20%]")
    dataset = dataset.rename_columns({"image": "images", "conversations": "messages"})
  3. Quantize the model (optional):

    python -m mlx_vlm.convert --hf-path unsloth/Llama-3.2-11B-Vision-Instruct -q --mlx-path Llama-3.2-11B-Vision-Instruct-4bit
  4. Start training:

    python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit --dataset 5CD-AI/Viet-ShareGPT-4o-Text-VQA --split Viet_OCR_VQA --steps 100 --learning-rate 5e-6 --lora-rank 16 --lora-alpha 16
[Screenshot: 2024-10-13 at 8:44:55 PM]
Blaizzy commented 2 weeks ago

cc: @awni

awni commented 2 weeks ago

There are a couple of things you should change in your Llama implementation in general:

  1. Use nn.RMSNorm instead of rolling your own
  2. Use nn.RoPE instead of rolling your own

Both will be (much) faster and numerically more stable. The NaNs come from an overflow in your RMSNorm implementation. Whenever you accumulate a lot of numbers, you typically need to accumulate the result in higher precision (in your case, the mean). nn.RMSNorm does this implicitly, without any casting between mx.float32 and mx.float16.
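Here is a minimal sketch of that overflow. It uses NumPy as a stand-in for MLX so the numbers are easy to check. The function names are illustrative and do not come from mlx-vlm.

Float16 cannot represent magnitudes above 65504. A sum of squares over a 4096-wide hidden state therefore overflows even for moderate activations if the result is kept in float16. Upcasting to float32 before squaring and reducing keeps the result finite.

```python
import numpy as np

def mean_square_fp16(x):
    # Naive reduction: square and sum with the result kept in float16.
    # 4096 * 100**2 ~= 4.1e7 exceeds float16's max (65504), so the sum is inf.
    with np.errstate(over="ignore"):
        total = np.sum(x * x, axis=-1, keepdims=True, dtype=np.float16)
    return total / x.shape[-1]

def rms_norm_fp32_accum(x, weight, eps=1e-5):
    # Upcast before squaring and reducing, then cast the result back
    # to the input dtype (float16 here).
    x32 = x.astype(np.float32)
    ms = np.mean(x32 * x32, axis=-1, keepdims=True)
    y = x32 / np.sqrt(ms + eps)
    return (weight.astype(np.float32) * y).astype(x.dtype)

x = np.full((1, 4096), 100, dtype=np.float16)
w = np.ones(4096, dtype=np.float16)
print(mean_square_fp16(x))               # [[inf]]
print(rms_norm_fp32_accum(x, w)[0, :4])  # [1. 1. 1. 1.]
```

Once an inf appears, operations such as `inf - inf` or `0 * inf` produce NaN, which is one way a loss can turn into NaN. In MLX itself, the fix suggested here is simply to replace the custom module with `nn.RMSNorm(dims, eps=...)`.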

I'd double-check that most of your model files use nn.RMSNorm or nn.LayerNorm where possible, and the same for RoPE. Inference especially will be much faster.

Blaizzy commented 2 weeks ago

Thanks a lot!

Yes, I was using a custom RMSNorm, I changed it to nn.RMSNorm and it's 3.25x faster 🚀.

As for RoPE, I was already using nn.RoPE, since it needed no changes and is easier to integrate with the cache.

> The NaNs are getting introduced during overflow in your RMSNorm implementation. Typically whenever you accumulate a lot of numbers you need to accumulate the result in a higher precision (so mean in your case). The nn.RMSNorm does this implicitly without the need for casting between mx.float32 and mx.float16.

How did you check this?

Blaizzy commented 2 weeks ago

@awni I made the recommended changes, but I can't seem to run training on my machine (M3 Max, 96GB).

It crashes with a segmentation fault after processing 3 samples, even with a batch size of 1.

{'Epoch': 0, 'Step': 0, 'Loss': '1.5820'}
  3%|█▍                                            | 3/100 [00:09<05:03,  3.13s/it, Epoch=0, Step=0, Loss=1.5820]
zsh: segmentation fault  python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit
/opt/homebrew/Caskroom/miniconda/base/envs/mlx_code/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Could you please try it on your M2 Ultra and see if the NaN loss persists?

awni commented 2 weeks ago

It's running on my M1 Max (32GB) with this command:

python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit --dataset 5CD-AI/Viet-ShareGPT-4o-Text-VQA --split Viet_OCR_VQA --steps 100 --learning-rate 5e-6 --lora-rank 16 --lora-alpha 16

It also uses the dataset modifications you posted above. So far it has processed 11 steps with no problems (I changed the print statement to print every step):

{'Epoch': 0, 'Step': 0, 'Loss': '1.5796'}
{'Epoch': 0, 'Step': 1, 'Loss': '1.8235'}
{'Epoch': 0, 'Step': 2, 'Loss': '1.9262'}
{'Epoch': 0, 'Step': 3, 'Loss': '1.5627'}
{'Epoch': 0, 'Step': 4, 'Loss': '1.5274'}
{'Epoch': 0, 'Step': 5, 'Loss': '1.7451'}
{'Epoch': 0, 'Step': 6, 'Loss': '1.9609'}
{'Epoch': 0, 'Step': 7, 'Loss': '0.9124'}
{'Epoch': 0, 'Step': 8, 'Loss': '1.7157'}
{'Epoch': 0, 'Step': 9, 'Loss': '1.6776'}
{'Epoch': 0, 'Step': 10, 'Loss': '1.8323'}
{'Epoch': 0, 'Step': 11, 'Loss': '1.4830'}

However, you should not be getting a segfault. That isn't good. Which version of MLX are you running? Anything else different in your setup?

Also, I notice the GPU utilization is pretty poor, which is also not good. It should be close to 100% during training, so there must be a bottleneck somewhere that needs fixing.

Blaizzy commented 2 weeks ago

Thanks!

Wow, that's really weird.

Here is my setup:

prince_canuma@MacBook-Pro-3 ~ % pip list | grep mlx
fastmlx                                   0.2.1
mlx                                       0.18.0
mlx-embeddings                            0.0.1             /Users/prince_canuma/Documents/Projects/LLMs/mlx-embeddings
mlx-lm                                    0.19.0
mlx-vlm                                   0.1.0             /Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm

> Also I notice the GPU utilization is pretty poor which is also not good. It should be close to 100% GPU utilization during training so there should be a bottleneck somewhere that needs fixing.

I suspect the dataset loading function. I know it's not the best, but I figured optimizing it could wait for the next release; this one already took long enough.

https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/trainer/trainer.py#L58

awni commented 2 weeks ago

Could you try upgrading to the latest MLX (0.18.1), and MLX LM (0.19.1) if it's used here? That would rule out something we have already fixed. (I think this PR may be related: https://github.com/ml-explore/mlx/pull/1452)
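After upgrading (for example with `pip install -U mlx mlx-lm`), you can confirm the new versions are actually installed with a small script like the one below. It reads installed versions through the standard-library `importlib.metadata`. The minimum versions are the ones mentioned above, and the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions suggested in this thread (MLX LM only matters if it is used).
MIN_VERSIONS = {"mlx": "0.18.1", "mlx-lm": "0.19.1"}

def as_tuple(v):
    """Turn '0.18.1' into (0, 18, 1), stopping at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)

def outdated(minimums=MIN_VERSIONS):
    """Return {package: installed_version} for installed packages below the minimum."""
    stale = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            continue  # not installed, so there is nothing to compare
        if as_tuple(installed) < as_tuple(minimum):
            stale[name] = installed
    return stale

if __name__ == "__main__":
    print(outdated() or "all good")
```

With the versions from the `pip list` output above (mlx 0.18.0, mlx-lm 0.19.0), both packages would be reported as outdated.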

Also, remind me: what machine and OS are you using?

> I suspect the dataset loading function. I know it's not the best but I thought it's an optimization for the next release this one already took long enough.

Data loading is often the issue. And yes, the next release is quite reasonable; I'm just letting you know in case you hadn't noticed.

Blaizzy commented 2 weeks ago

Upgrading to v0.18.1 fixed it! 🚀

> Data loading is often the issue. And yes next release is quite reasonable.. just letting you know in case you didn't notice it.

Thank you! Do you have any tips specific to MLX?

When I started getting the error, I figured it could be the data loading. So I made some rough initial optimizations: using a generator, deleting the batch after processing, and using the Metal clear cache command.

Blaizzy commented 2 weeks ago

> Also remind me what's your machine and OS?

MacBook Pro 14-inch
Chip: M3 Max
URAM: 96GB
OS: Sonoma 14.5

awni commented 2 weeks ago

> Thank you! Do you have any tips specific to MLX?

First, verify that data loading is in fact the issue. I would do that by reusing the same batch over and over instead of loading a new one each step, and checking that GPU utilization gets close to 100%.
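A sketch of that check follows. It times the same number of training steps twice: once with a single cached batch reused every step, and once with fresh batches from the real loader. `train_step` and `make_loader` are hypothetical placeholders for the trainer's step function and data iterator.

Because MLX evaluates lazily, `train_step` must force computation (for example by calling `mx.eval` on the loss and updated parameters). Otherwise the timing only measures graph construction.

```python
import itertools
import time

def seconds_per_step(train_step, batches, num_steps):
    """Average wall-clock seconds per step over at most `num_steps` batches."""
    steps = 0
    start = time.perf_counter()
    for batch in itertools.islice(batches, num_steps):
        train_step(batch)  # in MLX this must block until the work is done (mx.eval)
        steps += 1
    return (time.perf_counter() - start) / max(steps, 1)

def data_loading_overhead(train_step, make_loader, num_steps=20):
    """Return (seconds/step with one reused batch, seconds/step with fresh batches)."""
    fixed_batch = next(iter(make_loader()))
    compute_only = seconds_per_step(
        train_step, itertools.repeat(fixed_batch), num_steps
    )
    with_loading = seconds_per_step(train_step, iter(make_loader()), num_steps)
    return compute_only, with_loading
```

Suppose `with_loading` is much larger than `compute_only`, and GPU utilization stays near 100% during the fixed-batch run. Then the loader is the bottleneck. Running a warm-up step before timing avoids counting one-off setup costs.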

If data loading is the problem, then look into what's actually slow: is it the I/O itself or some preprocessing step?
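To answer that question, one option is to time each stage of the input pipeline separately. The sketch below assumes the pipeline can be split into two callables:

- `load` reads a raw example (for example, an image and its text).
- `preprocess` prepares it for the model (for example, the Hugging Face processor).

Both names are placeholders, not mlx-vlm functions.

```python
import time
from collections import defaultdict

def profile_pipeline(load, preprocess, indices):
    """Return the average seconds per example spent in each stage."""
    totals = defaultdict(float)
    for i in indices:
        t0 = time.perf_counter()
        raw = load(i)        # I/O: read the example from disk or memory
        t1 = time.perf_counter()
        preprocess(raw)      # CPU work: resizing, normalizing, tokenizing, ...
        t2 = time.perf_counter()
        totals["load"] += t1 - t0
        totals["preprocess"] += t2 - t1
    n = max(len(indices), 1)
    return {stage: secs / n for stage, secs in totals.items()}
```

If `preprocess` dominates even when the dataset is already in memory, I/O is not the problem.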

> and using the metal clear cache command.

I wouldn't manually clear the cache unless you have a really good reason. That will typically just slow everything down.

Blaizzy commented 2 weeks ago

Awesome, thanks!

> If you preload the dataset into RAM it probably isn't the IO

> Do you do the preprocessing in MLX? If not, maybe try doing that so it runs fast on the GPU..

I preload/prefetch the batch before running it.

Then the HF processor I use here to prepare the inputs is probably the bottleneck.
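If the HF processor is the bottleneck, one common way to hide its cost is a background thread. The thread keeps a small buffer of ready batches while the main thread trains. The sketch below uses only the standard library. It is not how mlx-vlm's trainer currently loads data, and `prefetch` is a hypothetical helper.

The overlap helps only to the extent that preprocessing releases the GIL (as many NumPy and PIL operations do) or the main thread is waiting on the GPU. Otherwise a process-based worker would be needed. To stay conservative, the sketch assumes the worker produces only NumPy or Python objects, with conversion to `mx.array` happening on the main thread.

```python
import queue
import threading

_DONE = object()  # sentinel marking the end of the stream

def prefetch(batches, buffer_size=2):
    """Yield items from `batches` while a background thread prepares the next ones."""
    buffer = queue.Queue(maxsize=buffer_size)
    errors = []

    def producer():
        try:
            for batch in batches:
                buffer.put(batch)  # blocks while the buffer is full
        except Exception as exc:
            errors.append(exc)  # hand the error to the consuming thread
        finally:
            buffer.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is _DONE:
            if errors:
                raise errors[0]
            return
        yield item
```

Usage would look like `for batch in prefetch(batch_iterator(), buffer_size=2): ...`, where `batch_iterator` is whatever generator currently runs the processor. If the consumer stops early, the daemon producer thread stays blocked on the full queue until the process exits.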

> I would do that by using the same batch over and over instead of loading it and make sure the GPU utilization is close to 100%.

Could you elaborate on that? I didn't quite get it.