jxiw / MambaInLlama

Official Repository of The Mamba in the Llama: Distilling and Accelerating Hybrid Models
https://arxiv.org/abs/2408.15237
Apache License 2.0

Training Slowdown for Llama3-Mamba2 #9

Open Codys12 opened 2 months ago

Codys12 commented 2 months ago

Hello! I am training the first two knowledge distillation stages of Mamba 2 on one DGX-H100x8 node, and I am seeing training times of ~8 hours for the first stage and ~13 hours for the second stage. Is this slowdown relative to the reported Mamba-1 training times expected?

Thank you!

jxiw commented 2 months ago

Thanks for this question.

I haven't tested the training time on the H100 since I don't have access to one, but it should be faster than on the A100. On the A100, Mamba 1 takes 5 hours for the 0.25 distillation and 5 hours for the 0.5 distillation with Zephyr teacher models, and 6.3 hours + 6.5 hours with Llama 3 models.

For Mamba 2, which uses Llama 3 as the teacher model, it takes 6 hours for the first stage and 7.5 hours for the second.

Just make sure you're not using DeepSpeed ZeRO-3. In my case, I only used the simplest Torch DDP setup, which is multi_gpu.yaml. If you use DeepSpeed ZeRO-3, it shards the model and optimizer across different GPUs, which may slow things down. Since I freeze most MLPs, I don't need the memory-efficiency features, which would otherwise introduce more communication overhead to synchronize those optimizer states. Alternatively, you can also try ZeRO-0, ZeRO-1, or ZeRO-2, which are faster.
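As a rough way to confirm this reasoning applies to your run (this is just a sketch, not code from the repo): count how many parameters actually require gradients. If only a small fraction is trainable, plain DDP is usually enough and ZeRO sharding mostly adds communication overhead. Here `model` is a stand-in for whatever torch.nn.Module your distillation script builds.

# Sketch: report trainable vs. total parameters of the hybrid model.
def param_report(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable / 1e9:.2f}B / total: {total / 1e9:.2f}B "
          f"({100.0 * trainable / total:.1f}% trainable)")

# usage (inside your training script, after the model is constructed):
# param_report(model)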

Codys12 commented 2 months ago

Odd, I am using multi_gpu.yaml with the same command as in the reproducibility guide, swapping the names from Mamba to Mamba 2. The 0.5 distillation is taking 12 hours on the H100s, which seems way too high. I am checking to see if my causal-conv1d is broken.

Codys12 commented 2 months ago

It is using causal-conv1d, so I'm not sure why the training is so slow...

jxiw commented 2 months ago

Check this run and see whether it helps: https://wandb.ai/junxiong12/mamba_distill/runs/5l1cvzxa/overview

Actually, that node has the slower disk. This run, on a node with a standard disk, takes 6 hours: https://wandb.ai/junxiong12/mamba_distill/runs/2gjsvg6c/overview?nw=nwuserjunxiong12
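If you suspect the disk, a rough probe of checkpoint-write throughput might look like the sketch below. The checkpoint directory is a placeholder; point it at the output directory your run actually writes to.

import os
import time

ckpt_dir = "./checkpoints"  # placeholder: use your actual output directory
os.makedirs(ckpt_dir, exist_ok=True)
probe = os.path.join(ckpt_dir, "_disk_probe.bin")
chunk = os.urandom(64 * 1024 * 1024)  # 64 MB buffer

start = time.time()
with open(probe, "wb") as f:
    for _ in range(16):  # 1 GB total
        f.write(chunk)
    f.flush()
    os.fsync(f.fileno())  # make sure the data actually hits the disk
elapsed = time.time() - start

os.remove(probe)
print(f"wrote 1 GB in {elapsed:.1f}s ({1.0 / elapsed:.2f} GB/s)")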

jxiw commented 2 months ago

Could it be because of your mamba-ssm version? I use this commit: https://github.com/state-spaces/mamba/tree/49ddf8321e4987650e8dc8dc44caa44b892f207a
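One quick way to compare environments is to print the installed versions and confirm that the fused kernels import cleanly rather than silently falling back to slower paths. This is only a sketch; the ssd_combined module path assumes a Mamba-2-era mamba-ssm release.

import importlib.metadata as md

for pkg in ("mamba-ssm", "causal-conv1d", "triton", "torch"):
    try:
        print(f"{pkg}: {md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED")

# These imports should succeed if the CUDA/Triton extensions were built correctly.
from causal_conv1d import causal_conv1d_fn  # fused causal conv kernel
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined  # Mamba-2 SSD scan
print("fused kernel imports OK")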

Codys12 commented 2 months ago

I installed from the latest source; would you recommend that specific version?

jxiw commented 2 months ago

Sorry, I am just trying to figure out why it is slower on your side. The latest one should behave the same.

Codys12 commented 2 months ago

Is there any info I can give you that would help diagnose this? I can drop a pip freeze...

jxiw commented 2 months ago

Here are a few things to check that might help:

  1. From what I remember, the initial steps report a large runtime, but it stabilizes after some time.
  2. Are you saving checkpoints very frequently? In my case, I use an NFS disk, and it takes more than 30 minutes to save a DeepSpeed checkpoint.
  3. Make sure your dependencies are consistent with requirements.txt (a quick check is sketched below).

If you figure out anything, please let me know.
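For item 3, here is a small sketch to spot version mismatches against pinned requirements. It assumes simple package==version lines and that you run it from the repo root.

import importlib.metadata as md

with open("requirements.txt") as f:
    for line in f:
        line = line.split("#")[0].strip()
        if not line or "==" not in line:
            continue  # skip comments, blanks, and unpinned entries
        name, wanted = line.split("==", 1)
        try:
            installed = md.version(name.strip())
        except md.PackageNotFoundError:
            installed = "NOT INSTALLED"
        mark = "" if installed == wanted.strip() else "  <-- mismatch"
        print(f"{name.strip()}: required {wanted.strip()}, installed {installed}{mark}")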

Codys12 commented 2 months ago

I am going to follow the reproduction steps as closely as I can with the environment.yml. Thank you for all the help setting this up!

jxiw commented 2 months ago

Thank you. Just want to mention that the SFT step takes the most time. For reproduction purposes, you could consider continuing with the following stages and fixing the distillation speed later.

To help you best reproduce the speed, this is my environment, https://wandb.ai/junxiong12/mamba_distill/runs/2gjsvg6c/files/conda-environment.yaml. Hope that helps!

Codys12 commented 1 month ago

I was able to complete the distillation stages and I have moved on to SFT! The datasets are almost done processing. Unfortunately, I was not able to figure out the speed issue with the Mamba layers... If all else fails, is it possible to modify the config to train over two nodes for increased training throughput?

jxiw commented 1 month ago

Thanks for this great question. Here is an example of how to enable multi-node training.

With DeepSpeed:

Node 1:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
main_training_function: main
mixed_precision: bf16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
machine_rank: 0
main_process_ip: xxx
main_process_port: 29501

Node 2:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
main_training_function: main
mixed_precision: bf16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
machine_rank: 1
main_process_ip: xxx
main_process_port: 29501

Without DeepSpeed:

Node 1:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
main_training_function: main
mixed_precision: bf16
machine_rank: 0
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_ip: xxx
main_process_port: 29501

Node 2:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
main_training_function: main
mixed_precision: bf16
machine_rank: 1
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_ip: xxx
main_process_port: 29501

But the performance depends on your network, e.g., whether NVLink or InfiniBand is enabled. My cluster's network is bad, so it does not give me a speedup.
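If you want to gauge whether the interconnect will give you a speedup before committing to a multi-node run, a rough all-reduce bandwidth probe is sketched below. It is only an illustration: the script name, tensor size, and launch command are placeholders, and it should be started with torchrun on each node so RANK/LOCAL_RANK/WORLD_SIZE are set (e.g. torchrun --nproc_per_node=8 bench_allreduce.py for a single node, plus the usual multi-node rendezvous flags across nodes).

import os
import time

import torch
import torch.distributed as dist

# torchrun sets RANK / LOCAL_RANK / WORLD_SIZE for us.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

x = torch.randn(64 * 1024 * 1024, device="cuda")  # ~256 MB of fp32
dist.all_reduce(x)  # warm-up
torch.cuda.synchronize()

iters = 20
start = time.time()
for _ in range(iters):
    dist.all_reduce(x)
torch.cuda.synchronize()
per_iter = (time.time() - start) / iters

if dist.get_rank() == 0:
    gigabytes = x.numel() * 4 / 1e9
    print(f"all_reduce of {gigabytes:.2f} GB: {per_iter * 1000:.1f} ms/iter")
dist.destroy_process_group()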