aliyun / aicb

Questions about 'megatron_workload_with_aiob.sh' and 'scripts/workload_moe.sh' #1

Closed: KaiLv16 closed this issue 3 months ago

KaiLv16 commented 4 months ago

Hi,

I was trying to reproduce the results in the Generate Workloads for Simulation (SimAI) section of the README. I ran the recommended commands but encountered the following exception:

root@sp-virtual-machine:/workspace/AICB# sh ./scripts/megatron_workload_with_aiob.sh \
-m 7 --world_size 4096 \
--tp_num 2 --pp_num 1 \
--comm_frame Megatron --global_batch 8192 \
--micro_batch 1 --seq_length 4096 \
--swiglu --use_flash_attn  --aiob_enable
python -m workload_generator.AIOB_simAI_workload_generator --comm_frame=Megatron --world_size=4096 --tp_num=4 --pp_num=1 --global_batch=8192 --micro_batch=1 --num_layers=36 --seq_length=4096 --hidden_size=4096 --epoch_num=1 --num_attention_heads=32 --model_name=gpt_7B --max_position_embeddings=4096 --vocab_size=50257 --use-distributed-optimizer --aiob_enable --use_flash_attn --swiglu
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/workload_generator/AIOB_simAI_workload_generator.py", line 606, in <module>
    args = get_params()
  File "/usr/local/lib/python3.10/dist-packages/utils/utils.py", line 357, in get_params
    assert (
AssertionError: moe must be enabled with sequence parallel

Do you have any ideas on how I can fix this?


I also ran another command from the README, but the shell cannot find the file 'scripts/workload_moe.sh'.

root@sp-virtual-machine:/workspace/AICB# sh scripts/workload_moe.sh \
-m moe --world_size 4096 --tp_num 2 --pp_num 1 --sp --expert_parallel_size 1 \
--num_moe_experts 2 --moe_router_topk 2  \
--comm_frame Megatron --global_batch 8192  \
--micro_batch 1 --seq_length 1024 --swiglu \
--use_flash_attn  --aiob_enable \
--comp_filepath workload/aiob_inputs/Example.txt
/bin/bash: scripts/workload_moe.sh: No such file or directory

Is the 'workload_moe.sh' file missing from the repository?

Thank you!

Huoyuan100861 commented 4 months ago

Expert parallelism is currently supported through the Megatron-Core implementation. Megatron-Core enforces the rule "When using expert parallelism and tensor parallelism, sequence parallelism must be used"; see https://github.com/NVIDIA/Megatron-LM/blob/6dbe4cf699880038b1e5cd90b23ee71053c7f2ee/megatron/core/model_parallel_config.py#L333. We will continue to update and extend expert parallelism support in the future.
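
The Megatron-Core rule quoted above can be read as a simple configuration check. The sketch below assumes the condition is "expert parallel size > 1 and tensor parallel size > 1 without sequence parallelism". The function `check_parallel_config` and its parameter names are hypothetical. This is an illustration, not Megatron-Core's or AICB's actual code.

```python
def check_parallel_config(tp_size: int, ep_size: int, sequence_parallel: bool) -> None:
    """Hypothetical check mirroring the Megatron-Core rule: expert parallelism
    combined with tensor parallelism requires sequence parallelism."""
    if ep_size > 1 and tp_size > 1 and not sequence_parallel:
        raise ValueError(
            "When using expert parallelism and tensor parallelism, "
            "sequence parallelism must be used"
        )


# TP=2 with EP=1: expert parallelism is not in use, so the rule does not apply.
check_parallel_config(tp_size=2, ep_size=1, sequence_parallel=False)

# TP=2 with EP=2 and sequence parallelism enabled: allowed.
check_parallel_config(tp_size=2, ep_size=2, sequence_parallel=True)

# TP=2 with EP=2 and sequence parallelism disabled: rejected.
try:
    check_parallel_config(tp_size=2, ep_size=2, sequence_parallel=False)
except ValueError as err:
    print(f"rejected: {err}")
```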

zhouheyang-alibaba commented 4 months ago

Apologies for the error in the README. Workload generation for MoE and expert parallelism is currently handled by scripts/megatron_gpt.sh, and SimAI workload generation by scripts/megatron_workload_with_aiob.sh. For MoE, only the alltoall token dispatcher is currently supported. Version v1.0 still lacks some features needed to run MoE workloads on physical machines; these will be addressed in the upcoming version.
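
Based on the reply above, the README's MoE command would target scripts/megatron_workload_with_aiob.sh instead of the nonexistent scripts/workload_moe.sh. The sketch below only builds that command line. It assumes the script accepts the same flags as the README's MoE command quoted earlier in this thread. That is an unverified assumption, so check the script's argument parsing before running it. The helper `build_command` is hypothetical.

```python
import shlex

# Flags based on the README's MoE command quoted earlier in this thread.
# ASSUMPTION: scripts/megatron_workload_with_aiob.sh accepts these same flags.
MOE_FLAGS = [
    "-m", "moe",
    "--world_size", "4096",
    "--tp_num", "2",
    "--pp_num", "1",
    "--sp",  # sequence parallelism, which the MoE check in utils.py requires
    "--expert_parallel_size", "1",
    "--num_moe_experts", "2",
    "--moe_router_topk", "2",
    "--comm_frame", "Megatron",
    "--global_batch", "8192",
    "--micro_batch", "1",
    "--seq_length", "1024",
    "--swiglu",
    "--use_flash_attn",
    "--aiob_enable",
    "--comp_filepath", "workload/aiob_inputs/Example.txt",
]


def build_command(script: str, flags: list[str]) -> str:
    """Hypothetical helper: return a shell-quoted command that runs `script` via sh."""
    return shlex.join(["sh", script, *flags])


cmd = build_command("scripts/megatron_workload_with_aiob.sh", MOE_FLAGS)
print(cmd)
```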

JoongunPark commented 3 months ago

Hi, I ran into the same problem. Is there a working example I can use to test SimAI? Thanks!

1195343015 commented 3 months ago

It looks like a bug in utils.py. I think editing the code here will help: https://github.com/aliyun/aicb/blob/cd91399267252cd8cd18bb185c1980606bf0c014/utils/utils.py#L357-L359

Add an if args.moe_enabled guard around the assertion:

if args.moe_enabled:
    assert (
        args.moe_enabled and args.enable_sequence_parallel
    ), "moe must be enabled with sequence parallel"

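The sketch below shows why the guard fixes the error. It infers from the traceback and the patch that the original check at utils.py lines 357-359 was unconditional. An unconditional check fails for every dense (non-MoE) model, because args.moe_enabled is False, and so the 7B command fails even though it never enables MoE. The function names are hypothetical, and argparse.Namespace stands in for AICB's parsed arguments.

```python
from argparse import Namespace


def check_moe_sp_original(args: Namespace) -> None:
    # Unconditional check (inferred): fails whenever MoE is disabled,
    # regardless of the sequence-parallel setting.
    assert (
        args.moe_enabled and args.enable_sequence_parallel
    ), "moe must be enabled with sequence parallel"


def check_moe_sp_patched(args: Namespace) -> None:
    # Suggested fix: enforce sequence parallelism only when MoE is enabled.
    if args.moe_enabled:
        assert args.enable_sequence_parallel, "moe must be enabled with sequence parallel"


dense = Namespace(moe_enabled=False, enable_sequence_parallel=False)
moe_no_sp = Namespace(moe_enabled=True, enable_sequence_parallel=False)
moe_sp = Namespace(moe_enabled=True, enable_sequence_parallel=True)

for name, check in [("original", check_moe_sp_original), ("patched", check_moe_sp_patched)]:
    for label, args in [("dense", dense), ("moe_no_sp", moe_no_sp), ("moe_sp", moe_sp)]:
        try:
            check(args)
            result = "ok"
        except AssertionError:
            result = "AssertionError"
        print(f"{name:9} {label:10} {result}")
```

Python skips assert statements when run with python -O. An explicit raise would keep this check active in that mode.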
zhouheyang-alibaba commented 3 months ago

Thank you for pointing out the issue. I will review it and make the necessary fixes in the upcoming version.

KaiLv16 commented 3 months ago

That works. Nice job!