Codys12 opened this issue 2 months ago
Thanks for this question.
I haven't tested the training time on the H100 since I don't have access to one, but it should be faster than the A100. On the A100, Mamba 1 takes 5 hours for the 0.25 distillation and 5 hours for the 0.5 distillation with the Zephyr teacher model, and 6.3 hours + 6.5 hours with the Llama 3 teacher models.
For Mamba 2, which uses Llama 3 as the teacher model, it takes 6 hours for the first stage and 7.5 hours for the second.
Just make sure you're not using DeepSpeed ZeRO-3. In my case, I only used the simplest Torch DDP, which is the multi_gpu.yaml. DeepSpeed ZeRO-3 shards the model and optimizer states across GPUs, which may slow things down. Since I freeze most MLPs, I don't need those memory-efficiency features, and they would otherwise introduce extra communication overhead to synchronize the sharded parameters and optimizer states. Alternatively, you can also try ZeRO-0, ZeRO-1, or ZeRO-2, which are faster.
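If you do want to stay with DeepSpeed, a minimal single-node accelerate config using ZeRO-2 might look like the sketch below. This is only an illustration, not the config shipped with the repo; the number of machines and processes are assumptions for a single 8-GPU node, so adjust them to your hardware:

```yaml
# Hypothetical single-node accelerate config using DeepSpeed ZeRO-2.
# ZeRO-2 shards gradients and optimizer states but keeps full parameters
# on every GPU, so it avoids the parameter-gather traffic of ZeRO-3.
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8    # assumption: one node with 8 GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```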
Odd, I am using multi_gpu.yaml with the same command as in the reproducibility guide, swapping the names from mamba to mamba 2. The 0.5 distillation is taking 12 hours on the H100s, which seems way too high. I am checking to see if my causalconv1d is broken.
It is using causalconv1d, so I'm not sure why the training is so slow...
Check this run and see whether it helps: https://wandb.ai/junxiong12/mamba_distill/runs/5l1cvzxa/overview
Actually, that node is the one with the slower disk. This run on a node with a standard disk takes 6 hours: https://wandb.ai/junxiong12/mamba_distill/runs/2gjsvg6c/overview?nw=nwuserjunxiong12
Could it be because of your mamba-ssm version? I use this commit: https://github.com/state-spaces/mamba/tree/49ddf8321e4987650e8dc8dc44caa44b892f207a.
I installed from the latest source; would you recommend that specific version?
Sorry, I am just trying to figure out why it is slower on your side. The latest one should behave the same.
Is there any info I can get to you that could help diagnose? I can drop a pip freeze….
Here are a few things to check that might help:
If you figure out anything, please let me know.
I am going to follow the reproduction steps as closely as I can with the environment.yml. Thank you for all the help setting this up!
Thank you. Just want to mention that the SFT step takes the most time. For reproduction purposes, you could consider continuing with the following stages and trying to fix the distillation speed later.
To help you best reproduce the speed, this is my environment: https://wandb.ai/junxiong12/mamba_distill/runs/2gjsvg6c/files/conda-environment.yaml. Hope that helps!
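As a hedged sketch of how the kernel packages could be pinned in an environment file like that one: the only value below taken from this thread is the mamba-ssm commit linked earlier; the env name, Python version, and everything else are placeholders, so compare against the linked conda-environment.yaml for the exact versions actually used.

```yaml
# Hypothetical excerpt of an environment.yml -- placeholder values except
# for the mamba-ssm commit hash, which is the commit linked earlier.
name: mamba_distill          # placeholder name
dependencies:
  - python=3.10              # placeholder Python version
  - pip
  - pip:
      # CUDA kernels for the causal conv1d used inside the Mamba layers
      - causal-conv1d
      # pin mamba-ssm to the specific commit referenced earlier in the thread
      - git+https://github.com/state-spaces/mamba.git@49ddf8321e4987650e8dc8dc44caa44b892f207a
```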
I was able to complete the distillation stages and I have moved on to SFT! The datasets are almost done processing. Unfortunately, I was not able to figure out the speed issue with the Mamba layers... If all else fails, is it possible to modify the config to train over two nodes for increased training throughput?
Thanks for this great question. Here is an example of how to enable multi-node training.

With DeepSpeed:
On node 0:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
main_training_function: main
mixed_precision: bf16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
machine_rank: 0
main_process_ip: xxx
main_process_port: 29501
```

On node 1:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
main_training_function: main
mixed_precision: bf16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
machine_rank: 1
main_process_ip: xxx
main_process_port: 29501
```
Without DeepSpeed:
On node 0:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
main_training_function: main
mixed_precision: bf16
machine_rank: 0
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_ip: xxx
main_process_port: 29501
```

On node 1:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
main_training_function: main
mixed_precision: bf16
machine_rank: 1
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_ip: xxx
main_process_port: 29501
```
But the performance depends a bit on your network, e.g., whether NVLink or InfiniBand is enabled. My cluster's network is bad, so it does not give me a speedup.
Hello! I am training the first two knowledge distillation stages of Mamba 2 on one DGX-H100x8 node, and I am seeing training times of ~8 hours for the first stage and ~13 hours for the second. Is this slowdown compared to the reported Mamba 1 training times expected?
Thank you!