databricks / megablocks

Apache License 2.0

Seeking a good multi-node training config #92

Open rpand002 opened 5 months ago

rpand002 commented 5 months ago

Thanks for the excellent work. Following the comment in #59, I am trying to train dmoe_760m on 16 GPUs (2 nodes). I changed the distributed arguments for a two-node setup, but training is very slow, as measured by elapsed time per iteration (ms). Can you suggest a good configuration for multi-node training? A complete multi-node training script would be very helpful.

@tgale96

```bash
#!/bin/bash
export PYTHONPATH="/dataset/g_ckpt/gaoyuanz/megablocks-public/third_party/Granite-Megatron-LM:${PYTHONPATH}"

export NCCL_SOCKET_IFNAME="ib,bond"
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_IB_PCI_RELAXED_ORDERING=1
export UCX_IB_PCI_RELAXED_ORDERING=on
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export NCCL_SOCKET_NTHREADS=2
export NCCL_NSOCKS_PERTHREAD=4
export CUDA_DEVICE_MAX_CONNECTIONS=1
export TOKENIZERS_PARALLELISM=false

MASTER_ADDR=$(echo ${LSB_MCPU_HOSTS} | tr ' ' '\n' | head -n 1)
MASTER_PORT=5${LSB_JOBID: -5:-1}
NNODES=$(echo ${LSB_MCPU_HOSTS} | tr ' ' '\n' | sed 'n; d' | wc -w)
GPUS_PER_NODE=$(echo $CUDA_VISIBLE_DEVICES | tr ',' '\n' | wc -w)
NODE_RANK=$(($(echo ${LSB_MCPU_HOSTS} | tr ' ' '\n' | sed 'n; d' | grep -n -m1 $HOSTNAME | cut -d':' -f1)-1))

EXPERIMENT_NAME="g-moe-1x4"

EXP_DIR="g-dmoe"

# scaling law: 16B tokens @ 760M = 32k steps.
#
# 512 * 1k * 400k = 200b tokens.
# 512 * 1k * 200k = 100b tokens.
# 512 * 1k * 100k = 50b tokens (default).
# 512 * 1k * 20k = 10b tokens.
# NOTE: these counts assume a global batch of 512 sequences; with
# --global-batch-size 2048 below, each step uses 4x as many tokens.
TRAINING_STEPS=20000
if [ -n "${2}" ]; then
    TRAINING_STEPS=$2;
fi

##
### Pre-training for GPT2 762M parameter.
##

# MoE hyperparameters.
MOE_ARGUMENTS="\
--moe-num-experts=64 \
--moe-loss-weight=0.1 \
--moe-top-k=1"

# Distributed hyperparameters.
DISTRIBUTED_ARGUMENTS="\
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"

# Model hyperparameters.
MODEL_ARGUMENTS="\
--num-layers 24 \
--hidden-size 1536 \
--num-attention-heads 16 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--activation-function gelu \
--attention-head-type multihead \
--normalization-function layernorm"

# Training hyperparameters.
TRAINING_ARGUMENTS="\
--micro-batch-size 4 \
--global-batch-size 2048 \
--train-iters ${TRAINING_STEPS} \
--lr-decay-iters ${TRAINING_STEPS} \
--lr 0.0004 \
--min-lr 0.00004 \
--lr-decay-style cosine \
--lr-warmup-fraction 0.01 \
--clip-grad 1.0 \
--init-method-std 0.01"

PILE_DATASET="\
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk1 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk2 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk3 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk4 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk5 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk6 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk7 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk8 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk9 \
1.0 \
/dataset/bluepile/slim_pajama_gptneox_megatron/train/chunk10"

# NOTE: We don't train for enough tokens for the
# split to matter.
DATA_ARGUMENTS="\
--data-path ${PILE_DATASET} \
--tokenizer-type HuggingFaceTokenizer \
--tokenizer-path /dataset/g_ckpt/cobol_exp/Granite-Megatron-LM/tokenizers/gpt-neox-20b \
--make-vocab-size-divisible-by 1024 \
--split 969,30,1"

COMPUTE_ARGUMENTS="\
--fp16 \
--DDP-impl local \
--moe-expert-model-parallelism \
--no-async-tensor-model-parallel-allreduce \
--use-flash-attn"

CHECKPOINT_ARGUMENTS="\
--save-interval 2000 \
--save ./${EXP_DIR}"

EVALUATION_ARGUMENTS="\
--eval-iters 100 \
--log-interval 1 \
--eval-interval 1000"

# tee cannot open the log file unless the directory already exists.
mkdir -p ./${EXP_DIR}

python -m torch.distributed.launch ${DISTRIBUTED_ARGUMENTS} \
       pretrain_gpt.py \
       ${MOE_ARGUMENTS} \
       ${MODEL_ARGUMENTS} \
       ${TRAINING_ARGUMENTS} \
       ${DATA_ARGUMENTS} \
       ${COMPUTE_ARGUMENTS} \
       ${CHECKPOINT_ARGUMENTS} \
       --fix-infiniband \
       ${EVALUATION_ARGUMENTS} |& tee ./${EXP_DIR}/train-20k.log
```

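The token-count comments in the script assume a global batch of 512 sequences of 1k tokens. A minimal bash sketch of the same arithmetic for the values the script actually passes, assuming 8 GPUs per node (16 GPUs on 2 nodes, as in the issue) and no tensor or pipeline parallelism, so every GPU is a data-parallel rank:

```bash
#!/bin/bash
# Values taken from the launch script; NNODES and GPUS_PER_NODE are assumed.
GLOBAL_BATCH_SIZE=2048
MICRO_BATCH_SIZE=4
SEQ_LENGTH=1024
TRAINING_STEPS=20000
NNODES=2
GPUS_PER_NODE=8

WORLD_SIZE=$((NNODES * GPUS_PER_NODE))
# Tensor and pipeline parallel sizes default to 1, so DP size == world size.
DATA_PARALLEL_SIZE=$WORLD_SIZE

# Megatron-LM requires the global batch to split evenly into micro-batches
# across the data-parallel ranks.
if (( GLOBAL_BATCH_SIZE % (MICRO_BATCH_SIZE * DATA_PARALLEL_SIZE) != 0 )); then
  echo "global batch not divisible by micro batch x DP size" >&2
  exit 1
fi

# Micro-batches each rank processes before one optimizer step.
GRAD_ACCUM_STEPS=$((GLOBAL_BATCH_SIZE / (MICRO_BATCH_SIZE * DATA_PARALLEL_SIZE)))

TOKENS_PER_STEP=$((GLOBAL_BATCH_SIZE * SEQ_LENGTH))
TOTAL_TOKENS=$((TOKENS_PER_STEP * TRAINING_STEPS))

echo "world size:        ${WORLD_SIZE}"
echo "grad accumulation: ${GRAD_ACCUM_STEPS}"
echo "tokens per step:   ${TOKENS_PER_STEP}"
echo "total tokens:      ${TOTAL_TOKENS}"
```

With these numbers each step consumes 2,097,152 tokens, so 20k steps is roughly 42B tokens rather than 10B, and each rank runs 32 micro-batches of 4 sequences per optimizer step.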
tgale96 commented 5 months ago

You're using our Megatron fork with MegaBlocks integrated? What kind of system are you on? A100, H100, etc.?

Soonhwan-Kwon commented 5 months ago

@tgale96 Thank you for the great work. I'm seeing the same slowdown as @rpand002. I'm on an A100 system, using your Megatron fork. A multi-node training script for reference would be a great help.

tgale96 commented 5 months ago

Our Megatron fork is mostly meant for small-scale experiments, and it uses the data-parallel process group for expert model parallelism. If you scale out to multiple nodes with both data parallelism and expert parallelism enabled, the expert parallelism will span those nodes, which can be slow because the all2alls then go over the inter-node network and become expensive.
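To make the cost concrete, here is a rough bash sketch of where the 64 experts land when expert parallelism reuses the data-parallel group across 2 nodes of 8 GPUs. It assumes the experts are split evenly over the 16 ranks, that global rank = node_rank × 8 + local_rank, and that the router spreads tokens uniformly over experts. Real routing is rarely uniform, so treat the result as an estimate, not a measurement.

```bash
#!/bin/bash
# Assumed layout: 2 nodes x 8 GPUs, experts sharded evenly over all ranks
# because the expert-parallel group is the data-parallel group.
NNODES=2
GPUS_PER_NODE=8
NUM_EXPERTS=64

WORLD_SIZE=$((NNODES * GPUS_PER_NODE))
EXPERTS_PER_RANK=$((NUM_EXPERTS / WORLD_SIZE))

# Show which experts each rank owns and which node it sits on.
for ((rank = 0; rank < WORLD_SIZE; rank++)); do
  node=$((rank / GPUS_PER_NODE))
  first=$((rank * EXPERTS_PER_RANK))
  last=$((first + EXPERTS_PER_RANK - 1))
  echo "rank ${rank} (node ${node}): experts ${first}-${last}"
done

# Under uniform routing, each destination rank is equally likely, so the
# share of a rank's all2all payload that leaves its node is:
REMOTE_RANKS=$((WORLD_SIZE - GPUS_PER_NODE))
echo "off-node share of all2all payload: ${REMOTE_RANKS}/${WORLD_SIZE}"
```

Under these assumptions half of every rank's all2all payload crosses to the other node; with NNODES=1 the off-node share drops to 0/8.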

One thing you could try is using pipeline parallelism between nodes. If you were to use MegaBlocks in a custom framework, I'd recommend using something like FSDP across nodes and expert parallelism within each node.
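One way to picture the suggested custom-framework layout is a bash sketch that only prints process-group membership. The group names are illustrative, the 2 × 8 GPU shape is assumed, and global rank = node_rank × GPUS_PER_NODE + local_rank as assigned by torch.distributed.launch. Expert-parallel groups hold the ranks of one node, and cross-node groups hold the ranks with the same local index; which parameters FSDP shards along which group is a design choice this sketch does not make.

```bash
#!/bin/bash
# Hypothetical 2-D layout: expert parallelism (EP) inside each node,
# FSDP-style sharding/replication across nodes.
NNODES=2
GPUS_PER_NODE=8

declare -a EP_GROUPS XN_GROUPS

# EP groups: every rank on one node, so the expert all2all stays on
# intra-node links.
for ((node = 0; node < NNODES; node++)); do
  group=""
  for ((local = 0; local < GPUS_PER_NODE; local++)); do
    group="${group:+$group }$((node * GPUS_PER_NODE + local))"
  done
  EP_GROUPS[node]="$group"
  echo "EP group (node ${node}): ${group}"
done

# Cross-node groups: the same local rank on every node. Only these groups
# communicate over the inter-node network.
for ((local = 0; local < GPUS_PER_NODE; local++)); do
  group=""
  for ((node = 0; node < NNODES; node++)); do
    group="${group:+$group }$((node * GPUS_PER_NODE + local))"
  done
  XN_GROUPS[local]="$group"
  echo "cross-node group (local rank ${local}): ${group}"
done
```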

I do not have reference scripts for multi-node training, but for pipeline parallelism the flags are the same as they are in upstream Megatron-LM. I hope this helps!
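For the pipeline-parallel route, a sketch using the upstream Megatron-LM flag names `--tensor-model-parallel-size` and `--pipeline-model-parallel-size`, assuming 2 nodes × 8 GPUs and the 24-layer model from the script above. Upstream Megatron-LM places pipeline stages outermost in the rank order, so with one stage per node each data-parallel group (and therefore each expert-parallel group in the fork) stays on one node. It is worth confirming this stage layout against the fork's process-group setup before relying on it.

```bash
#!/bin/bash
# Assumed cluster shape and model depth; one pipeline stage per node.
NNODES=2
GPUS_PER_NODE=8
NUM_LAYERS=24
TP=1
PP=$NNODES

WORLD_SIZE=$((NNODES * GPUS_PER_NODE))
DP=$((WORLD_SIZE / (TP * PP)))

# Upstream Megatron-LM flags; append these to the launch command.
PARALLEL_ARGUMENTS="\
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP}"

# With TP=1, stage s owns the contiguous ranks [s*DP, (s+1)*DP), which
# form its data-parallel group and sit on a single node.
for ((stage = 0; stage < PP; stage++)); do
  first_layer=$((stage * NUM_LAYERS / PP))
  last_layer=$(( (stage + 1) * NUM_LAYERS / PP - 1 ))
  first_rank=$((stage * DP))
  last_rank=$(( (stage + 1) * DP - 1 ))
  echo "stage ${stage}: layers ${first_layer}-${last_layer}, DP ranks ${first_rank}-${last_rank}"
done
echo "${PARALLEL_ARGUMENTS}"
```

With these values each stage holds 12 layers and the data-parallel size is 8, so `--global-batch-size 2048` with `--micro-batch-size 4` gives 64 micro-batches per rank per step, which also keeps the pipeline well filled.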