microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Failed to inference Megatron gpt-3 MoE model with `deepspeed.init_inference` #2183

Closed Gabriel4256 closed 1 year ago

Gabriel4256 commented 2 years ago

Describe the bug: When I use `deepspeed.init_inference` for the Megatron GPT-3 MoE model in the Megatron repo, an error occurs. There is no problem when I use `deepspeed.initialize` instead, as done in the training stage, but `init_inference` seems to be the proper function for inference. If there is no performance difference at all, I will just use `deepspeed.initialize`.
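For reference, a minimal sketch of the two entry points (assumptions: a toy stand-in model and config, DeepSpeed 0.7.x keyword arguments, run under the `deepspeed` launcher so torch.distributed is set up):

import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in for the Megatron model

# Training/eval path: returns a DeepSpeedEngine (plus optimizer, dataloader, scheduler)
engine, _, _, _ = deepspeed.initialize(model=model, config={"train_batch_size": 1})

# Inference path: returns an InferenceEngine; for supported models,
# replace_with_kernel_inject=True additionally swaps in fused inference kernels
infer_engine = deepspeed.init_inference(model, dtype=torch.half)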

To Reproduce: Steps to reproduce the behavior:

I used ds_pretrain_gpt_1.3B_MoE128.sh and pretrain_gpt.py with some modifications.

contents of ds_pretrain_gpt_1.3B_MoE128.sh (I used the GPT-3 Small 125M config and changed some GPU settings):

#!/bin/bash
DIR=`pwd`
###############################################################################
### Main configs
## GPT-3 models use 2K sequence length/context window
SEQ_LEN=2048

### The "GPT-3 XXX" below are configs from GPT-3 paper
### https://arxiv.org/abs/2005.14165, choose based on
### your desired model size or build your own configs

## GPT-3 Small 125M
MODEL_SIZE=0.125
NUM_LAYERS=12
HIDDEN_SIZE=768
NUM_ATTN_HEADS=12
GLOBAL_BATCH_SIZE=256
# LR=6.0e-4
# MIN_LR=6.0e-5

## GPT-3 Medium 350M
# MODEL_SIZE=0.35
# NUM_LAYERS=24
# HIDDEN_SIZE=1024
# NUM_ATTN_HEADS=16
# GLOBAL_BATCH_SIZE=256
# LR=3.0e-4
# MIN_LR=3.0e-5

## GPT-3 Large 760M
# MODEL_SIZE=0.76
# NUM_LAYERS=24
# HIDDEN_SIZE=1536
# NUM_ATTN_HEADS=16
# GLOBAL_BATCH_SIZE=256
# LR=2.5e-4
# MIN_LR=2.5e-5

## GPT-3 XL 1.3B
# MODEL_SIZE=1.3
# NUM_LAYERS=24
# HIDDEN_SIZE=2048
# NUM_ATTN_HEADS=16
# GLOBAL_BATCH_SIZE=512
# LR=2.0e-4
# MIN_LR=2.0e-5

## GPT-3 2.7B
# MODEL_SIZE=2.7
# NUM_LAYERS=32
# HIDDEN_SIZE=2560
# NUM_ATTN_HEADS=32
# GLOBAL_BATCH_SIZE=512
# LR=1.6e-4
# MIN_LR=1.6e-5

## GPT-3 6.7B
# MODEL_SIZE=6.7
# NUM_LAYERS=32
# HIDDEN_SIZE=4096
# NUM_ATTN_HEADS=32
# GLOBAL_BATCH_SIZE=1024
# LR=1.2e-4
# MIN_LR=1.2e-5

## GPT-3 13B
# MODEL_SIZE=13
# NUM_LAYERS=40
# HIDDEN_SIZE=5120
# NUM_ATTN_HEADS=40
# GLOBAL_BATCH_SIZE=1024
# LR=1.0e-4
# MIN_LR=1.0e-5

## GPT-3 175B
# MODEL_SIZE=175
# NUM_LAYERS=96
# HIDDEN_SIZE=12288
# NUM_ATTN_HEADS=96
# GLOBAL_BATCH_SIZE=1536
# LR=0.6e-4
# MIN_LR=0.6e-5
###############################################################################
### Training duration configs
## The main termination condition; the original GPT-3 paper trains for 300B tokens
## For MoE models, we found that sometimes training a bit longer, to 330B tokens, helps
TRAIN_TOKENS=300000000000
# TRAIN_TOKENS=330000000000

## TRAIN_ITERS is another termination condition and also affects the number of
## data samples to be indexed. Since we want to reach the TRAIN_TOKENS above,
## and techniques like curriculum learning use fewer tokens in some steps, we
## just set this config large enough to make sure we have enough processed
## data and don't terminate by TRAIN_ITERS.
TRAIN_ITERS=$(( ${TRAIN_TOKENS} * 3 / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} ))
# TRAIN_ITERS=0
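# With the 125M config above this evaluates to
# 300e9 * 3 / 256 / 2048 = 1,716,613 iterations, which matches the
# "--train-iters 1716613" visible in the launch log further below.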

## Another termination condition in minutes. Set it large enough to avoid
## undesired early termination.
EXIT_DURATION=300
###############################################################################
### LR configs
## LR warmup and decay duration; this token-based config is preferable since it
## needs no readjustment when the batch size/seqlen is changed.
## Original GPT-3 paper uses 375M warmup tokens and 260B decay tokens.
## For MoE model, we found that setting the decay token to 300B helps.
WARMUP_TOKENS=375000000
# LR_DECAY_TOKENS=260000000000
LR_DECAY_TOKENS=300000000000
###############################################################################
### Parallelism configs
## Micro batch size per GPU
## Make sure that BATCH_SIZE <= GLOBAL_BATCH_SIZE*PP_SIZE*MP_SIZE/NUM_GPUS
BATCH_SIZE=8

## Model parallelism, 1 is no MP
## Currently MoE models have divergence issue when MP > 1.
MP_SIZE=1

## Pipeline parallelism
## Currently we don't support PP for MoE. To disable PP, set PP_SIZE
## to 1 and use the "--no-pipeline-parallel" arg.
PP_SIZE=1
NUM_GPUS=2
###############################################################################
### MoE configs
## Number of experts. EP_SIZE 1 means dense model without MoE
# EP_SIZE=1
EP_SIZE=2

if [[ $EP_SIZE -gt $NUM_GPUS ]]; then
    EP_PARALLEL_SIZE=$NUM_GPUS
else
    EP_PARALLEL_SIZE=$EP_SIZE
fi
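# In other words, EP_PARALLEL_SIZE = min(EP_SIZE, NUM_GPUS); with EP_SIZE=2 and
# NUM_GPUS=2 here, expert parallelism spans both GPUs.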

## The original GPT-3 models always set the min LR at 10% of the max LR. For MoE
## models, we found that a lower LR and min LR (than the base dense model) help.
## For the 1.3B MoE-128 model we used LR=1.2e-4 and MIN_LR=1.0e-6.
## For the 350M MoE-128 model we used LR=2.0e-4 and MIN_LR=2.0e-6, but these
## were not heavily tuned.
LR=1.2e-4
MIN_LR=1.0e-6

## Coefficient for MoE loss. We find that 0.01 is a good value at least for
## 1.3B MoE-128 model
MLC=0.01

## Below configs adjust the MoE expert token capacity limit during training and
## eval. To completely disable capacity limit, set MOE_DROP_TOKEN to false.
## Larger capacity factor or disabling capacity limit could improve training
## convergence, but will also reduce training throughput.
MOE_TRAIN_CAP_FACTOR=1.0
MOE_EVAL_CAP_FACTOR=1.0
MOE_MIN_CAP=4
MOE_DROP_TOKEN="true"
# MOE_DROP_TOKEN="false"
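# (Assumption, for intuition: with top-1 gating the per-expert capacity is
# roughly max(MOE_MIN_CAP, CAP_FACTOR * tokens_in_batch / num_experts).)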
###############################################################################
### Curriculum learning (CL) configs
## Enable/disable CL
CL_ENABLED="false"
## Consult the tutorial https://www.deepspeed.ai/tutorials/curriculum-learning/
## for tuning the following configs
CL_START_SEQLEN=80
CL_AVG_SEQLEN=$(( (${CL_START_SEQLEN} + ${SEQ_LEN}) / 2 ))
CL_TOKENS=60
CL_TOKENS=$((${CL_TOKENS} * 1000000000))
CL_STEP=$(( ${CL_TOKENS} / (${GLOBAL_BATCH_SIZE} * ${CL_AVG_SEQLEN}) ))
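# With the defaults above: CL_AVG_SEQLEN = (80 + 2048) / 2 = 1064 and
# CL_STEP = 60e9 / (256 * 1064) ≈ 220,277 steps (unused here, since CL_ENABLED=false).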
###############################################################################
### Misc configs
LOG_INTERVAL=10
EVAL_ITERS=10
EVAL_INTERVAL=100
SAVE_INTERVAL=10000

## Standard deviation for weight initialization
## We used 0.014 for the 350M/1.3B dense/MoE models, and 0.01 for the 6.7B
## dense model. Usually a larger model needs a lower std.
INIT_STD=0.014
# INIT_STD=0.01

## Activation checkpointing saves GPU memory, but reduces training speed
ACTIVATION_CHECKPOINT="true"
# ACTIVATION_CHECKPOINT="false"
###############################################################################
### Output and data configs
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
host="${HOSTNAME}"
NAME="gpt-${MODEL_SIZE}B-lr-${LR}-minlr-${MIN_LR}-bs-${GLOBAL_BATCH_SIZE}-gpus-${NUM_GPUS}-mp-${MP_SIZE}-pp-${PP_SIZE}"
if [[ $EP_SIZE -gt 1 ]]; then
    NAME="${NAME}-ep-${EP_SIZE}-mlc-${MLC}-cap-${MOE_TRAIN_CAP_FACTOR}-drop-${MOE_DROP_TOKEN}"
fi
if [ "${CL_ENABLED}" = "true" ]; then
    NAME="${NAME}-cl-${CL_START_SEQLEN}-${CL_STEP}"
fi
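# With the settings above, NAME expands to
# gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true
# (matching the checkpoint/tensorboard paths in the launch log below).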

OUTPUT_BASEPATH=$DIR/output
mkdir -p "${OUTPUT_BASEPATH}/tensorboard/"
mkdir -p "${OUTPUT_BASEPATH}/checkpoint/"
mkdir -p "${OUTPUT_BASEPATH}/log/"
TENSORBOARD_DIR="${OUTPUT_BASEPATH}/tensorboard/${NAME}_${host}_${current_time}"
mkdir -p ${TENSORBOARD_DIR} 
## Note that for an MoE model with a billion-scale base model, the checkpoint
## can be as large as TB-scale, which normal NFS cannot handle efficiently.
CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/${NAME}"

# USE_INTERNAL_DATA="true"
USE_INTERNAL_DATA="false"

if [ "${USE_INTERNAL_DATA}" = "true" ]; then
    ## The internal data is only accessible within Microsoft
    ## For cluster Azure-EastUS-V100-32GB-4, Azure-WestUS3-A100
    # BASE_DATA_PATH=/vc_data/Megatron-LM/data
    # DATA_HOME="/vc_data/pile-cc1-cc2-shuf"
    ## For cluster Lab-RR1-V100
    BASE_DATA_PATH=/data/Megatron-LM/data
    DATA_HOME="/turing-ssd/users/conglli/data/pile-cc1-cc2-shuf"
    ## For cluster Azure-CentralUS-A100
    # BASE_DATA_PATH=/data/Megatron-LM/data
    # DATA_HOME=/vc_data_1/users/amawa/blended

    VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
    MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt
    ARX="${DATA_HOME}/ArXiv_ftfy_cleaned_id_shuf_text_document"
    BC2="${DATA_HOME}/BookCorpus2_ftfy_cleaned_id_shuf_text_document"
    B3="${DATA_HOME}/Books3_ftfy_cleaned_id_shuf_text_document"
    CC2020="${DATA_HOME}/CC-2020-50_id_cleaned_shuf_text_document"
    CC2021="${DATA_HOME}/CC-2021-04_id_cleaned_shuf_text_document"
    GIT="${DATA_HOME}/Github_ftfy_id_shuf_text_document"
    GUT="${DATA_HOME}/Gutenberg_PG-19_ftfy_cleaned_id_cleaned_shuf_text_document"
    NIH="${DATA_HOME}/NIH_ExPorter_ftfy_id_shuf_text_document"
    OWT2="${DATA_HOME}/OpenWebText2_ftfy_cleaned_id_shuf_text_document"
    PCC="${DATA_HOME}/Pile-CC_id_cleaned_shuf_text_document"
    PM="${DATA_HOME}/PubMed_Abstracts_ftfy_id_shuf_text_document"
    RN="${DATA_HOME}/rn_dedup_shuf_cleaned_0.7_cleaned_shuf_text_document"
    SE="${DATA_HOME}/StackExchange_ftfy_id_shuf_text_document"
    ST="${DATA_HOME}/stories_dedup0.7_shuf_cleaned_shuf_text_document"
    WIK="${DATA_HOME}/Wikipedia_en_ftfy_id_shuf_text_document"
    DATA_BLEND="0.14336 ${B3} 0.08962 ${RN} 0.19336 ${OWT2} 0.05689 ${SE} \
    0.00859 ${ST} 0.02897 ${PM} 0.04771 ${WIK} 0.00873 ${GUT} 0.01007 ${BC2} \
    0.00208 ${NIH} 0.13017 ${CC2020} 0.09446 ${PCC} 0.15652 ${CC2021} \
    0.01359 ${ARX} 0.01588 ${GIT}"
else
    VOCAB_PATH=/home/ubuntu/frameworks/Megatron-DeepSpeed/gpt2-vocab.json
    MERGE_PATH=/home/ubuntu/frameworks/Megatron-DeepSpeed/gpt2-merges.txt
    # The public Pile dataset can be downloaded at https://mystic.the-eye.eu/public/AI/pile_neox/
    DATA_BLEND=/home/ubuntu/frameworks/Megatron-DeepSpeed/dataset/BookCorpusDataset/BookCorpusDataset_text_document
fi
###############################################################################
data_options=" \
         --vocab-file ${VOCAB_PATH} \
         --merge-file ${MERGE_PATH} \
         --data-path ${DATA_BLEND} \
         --data-impl mmap"

megatron_options=" \
        --override-lr-scheduler \
        --adam-beta1 0.9 \
        --adam-beta2 0.95 \
        --tensor-model-parallel-size ${MP_SIZE} \
        --moe-expert-parallel-size ${EP_PARALLEL_SIZE} \
        --num-experts ${EP_SIZE} \
        --moe-loss-coeff ${MLC} \
        --moe-train-capacity-factor ${MOE_TRAIN_CAP_FACTOR} \
        --moe-eval-capacity-factor ${MOE_EVAL_CAP_FACTOR} \
        --moe-min-capacity ${MOE_MIN_CAP} \
        --init-method-std ${INIT_STD} \
        --lr-decay-tokens ${LR_DECAY_TOKENS} \
        --lr-warmup-tokens ${WARMUP_TOKENS} \
        --micro-batch-size ${BATCH_SIZE} \
        --exit-duration-in-mins ${EXIT_DURATION} \
        --global-batch-size ${GLOBAL_BATCH_SIZE} \
        --num-layers ${NUM_LAYERS} \
        --hidden-size ${HIDDEN_SIZE} \
        --num-attention-heads ${NUM_ATTN_HEADS} \
        --seq-length ${SEQ_LEN} \
        --max-position-embeddings ${SEQ_LEN} \
        --train-tokens ${TRAIN_TOKENS} \
        --train-iters ${TRAIN_ITERS} \
        --lr ${LR} \
        --min-lr ${MIN_LR} \
        --lr-decay-style cosine \
        --split 98,2,0 \
        --log-interval ${LOG_INTERVAL} \
        --eval-interval ${EVAL_INTERVAL} \
        --eval-iters ${EVAL_ITERS} \
        --save-interval ${SAVE_INTERVAL} \
        --weight-decay 0.1 \
        --clip-grad 1.0 \
        --hysteresis 2 \
        --num-workers 0 \
        --fp16 \
        --load ${CHECKPOINT_PATH} \
        --tensorboard-queue-size 1 \
        --log-timers-to-tensorboard \
        --log-batch-size-to-tensorboard \
        --log-validation-ppl-to-tensorboard \
        --tensorboard-dir ${TENSORBOARD_DIR}"

        # --save ${CHECKPOINT_PATH} \
if [ "${ACTIVATION_CHECKPOINT}" = "true" ]; then
megatron_options="${megatron_options} \
        --checkpoint-activations"
fi

if [[ $EP_SIZE -gt 1 ]]; then
megatron_options="${megatron_options} \
        --create-moe-param-group"
fi

if [ "${MOE_DROP_TOKEN}" = "false" ]; then
megatron_options="${megatron_options} \
        --disable-moe-token-dropping"
fi

template_json="ds_config_gpt_TEMPLATE.json"
config_json="ds_config_gpt_${NAME}.json"
sed "s/CONFIG_BATCH_SIZE/${GLOBAL_BATCH_SIZE}/" ${template_json} \
    | sed "s/CONFIG_MBSIZE/${BATCH_SIZE}/" \
    | sed "s/LOG_INTERVAL/${LOG_INTERVAL}/" \
    | sed "s/ZERO_STAGE/0/" \
    | sed "s/PRESCALE_GRAD/true/" \
    | sed "s/CONFIG_FP16_ENABLED/true/" \
    | sed "s/CONFIG_BF16_ENABLED/false/" \
    | sed "s/CONFIG_CL_ENABLED/${CL_ENABLED}/" \
    | sed "s/CONFIG_CL_MIN/${CL_START_SEQLEN}/" \
    | sed "s/CONFIG_CL_MAX/${SEQ_LEN}/" \
    | sed "s/CONFIG_CL_DURATION/${CL_STEP}/" \
      > ${config_json}
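# The generated ds_config_gpt_${NAME}.json enables fp16 (bf16 off), ZeRO stage 0,
# and gradient prescaling, with the global/micro batch sizes set above.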

deepspeed_options=" \
            --deepspeed \
            --deepspeed_config ${config_json} \
            --pipeline-model-parallel-size ${PP_SIZE}"

# Currently MoE is not compatible with pipeline parallel
if [[ $EP_SIZE -gt 1 ]]; then
deepspeed_options="${deepspeed_options} \
        --no-pipeline-parallel"
fi

if [ "${ACTIVATION_CHECKPOINT}" = "true" ]; then
deepspeed_options="${deepspeed_options} \
        --deepspeed-activation-checkpointing"
fi

run_cmd="deepspeed ${DIR}/../../pretrain_gpt.py ${megatron_options} ${data_options} ${deepspeed_options}"
# To capture output in a file, append: &> ${OUTPUT_BASEPATH}/log/${NAME}_${host}_${current_time}.log
echo ${run_cmd}
eval ${run_cmd}
set +x

contents of the main function in pretrain_gpt.py:

if __name__ == "__main__":
    git_ds_info()

    from megatron import get_args, initialize_megatron
    from megatron.training import get_model
    import deepspeed

    initialize_megatron(extra_args_provider=None, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

    # get_model returns a list of model chunks; wrap the first one for inference
    model = get_model(model_provider)
    model = deepspeed.init_inference(model[0])

With these files, I executed

$ bash ds_pretrain_gpt_1.3B_MoE128.sh

Expected behavior: I expect an InferenceEngine for the model to be created successfully by init_inference.

ds_report output

/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/setuptools/_distutils/core.py
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  using untested triton version (1.1.1), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/torch']
torch version .................... 1.12.0
torch cuda version ............... 11.6
torch hip version ................ None
nvcc version ..................... 11.6
deepspeed install path ........... ['/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.7.0, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.6

Screenshots: The following is the error message:

Traceback (most recent call last):
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/../../pretrain_gpt.py", line 297, in <module>
    model = deepspeed.init_inference(model[0])
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/__init__.py", line 292, in init_inference
    engine = InferenceEngine(model,
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/inference/engine.py", line 136, in __init__
    self._apply_injection_policy(
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/inference/engine.py", line 329, in _apply_injection_policy
    replace_transformer_layer(
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 780, in replace_transformer_layer
    replaced_module = replace_module(model=model,
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 962, in replace_module
    replaced_module, _ = _replace_module(model, policy)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 989, in _replace_module
    _, layer_id = _replace_module(child, policies, layer_id=layer_id)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 989, in _replace_module
    _, layer_id = _replace_module(child, policies, layer_id=layer_id)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 989, in _replace_module
    _, layer_id = _replace_module(child, policies, layer_id=layer_id)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 979, in _replace_module
    replaced_module = policies[child.__class__][0](child,
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 776, in replace_fn
    new_module = replace_wo_policy(child, _policy)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 755, in replace_wo_policy
    return _replace_module(module)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 752, in _replace_module
    _replace_module(child, name)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 752, in _replace_module
    _replace_module(child, name)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 752, in _replace_module
    _replace_module(child, name)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 747, in _replace_module
    linear_policies[child.__class__](child,
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/module_inject/replace_module.py", line 629, in _replace
    if name in all_reduce_linears:
TypeError: argument of type 'ABCMeta' is not iterable
(The second rank prints an identical traceback.)
[2022-08-04 13:46:55,904] [INFO] [launch.py:286:sigkill_handler] Killing subprocess 402082
[2022-08-04 13:46:55,917] [INFO] [launch.py:286:sigkill_handler] Killing subprocess 402083
[2022-08-04 13:46:55,917] [ERROR] [launch.py:292:sigkill_handler] ['/home/ubuntu/miniconda3/envs/tutel/bin/python3.9', '-u', '/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/../../pretrain_gpt.py', '--local_rank=1', '--override-lr-scheduler', '--adam-beta1', '0.9', '--adam-beta2', '0.95', '--tensor-model-parallel-size', '1', '--moe-expert-parallel-size', '2', '--num-experts', '2', '--moe-loss-coeff', '0.01', '--moe-train-capacity-factor', '1.0', '--moe-eval-capacity-factor', '1.0', '--moe-min-capacity', '4', '--init-method-std', '0.014', '--lr-decay-tokens', '300000000000', '--lr-warmup-tokens', '375000000', '--micro-batch-size', '8', '--exit-duration-in-mins', '300', '--global-batch-size', '256', '--num-layers', '12', '--hidden-size', '768', '--num-attention-heads', '12', '--seq-length', '2048', '--max-position-embeddings', '2048', '--train-tokens', '300000000000', '--train-iters', '1716613', '--lr', '1.2e-4', '--min-lr', '1.0e-6', '--lr-decay-style', 'cosine', '--split', '98,2,0', '--log-interval', '10', '--eval-interval', '100', '--eval-iters', '10', '--save-interval', '10000', '--weight-decay', '0.1', '--clip-grad', '1.0', '--hysteresis', '2', '--num-workers', '0', '--fp16', '--load', '/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/output/checkpoint/gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true', '--tensorboard-queue-size', '1', '--log-timers-to-tensorboard', '--log-batch-size-to-tensorboard', '--log-validation-ppl-to-tensorboard', '--tensorboard-dir', '/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/output/tensorboard/gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true_f09_2022.08.04-13.46.47', '--checkpoint-activations', '--create-moe-param-group', '--vocab-file', '/home/ubuntu/frameworks/Megatron-DeepSpeed/gpt2-vocab.json', '--merge-file', '/home/ubuntu/frameworks/Megatron-DeepSpeed/gpt2-merges.txt', '--data-path', '/home/ubuntu/frameworks/Megatron-DeepSpeed/dataset/BookCorpusDataset/BookCorpusDataset_text_document', '--data-impl', 'mmap', '--deepspeed', '--deepspeed_config', 'ds_config_gpt_gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true.json', '--pipeline-model-parallel-size', '1', '--no-pipeline-parallel', '--deepspeed-activation-checkpointing'] exits with return code = 1
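For context on the failure: `name in all_reduce_linears` raising `TypeError: argument of type 'ABCMeta' is not iterable` means that `all_reduce_linears` is a class object at that point (its metaclass is `ABCMeta`) rather than a tuple/dict of parameter names, so the membership test has nothing to iterate. A minimal, purely illustrative repro of the mechanism (hypothetical class, not DeepSpeed code):

from abc import ABC

class InjectionPolicy(ABC):  # hypothetical; type(InjectionPolicy) is ABCMeta
    pass

# `in` needs __contains__ or __iter__ on the right-hand side; a bare class
# provides neither, so this raises:
# TypeError: argument of type 'ABCMeta' is not iterable
"dense_4h_to_h" in InjectionPolicy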

System info (please complete the following information):

Launcher context: I tested with the modified files and command written above.

Gabriel4256 commented 2 years ago

I managed to create an InferenceEngine by adding some configs, but another problem occurs when running its forward pass. The following is the revised pretrain_gpt.py:

import torch
import deepspeed

from megatron import get_args, initialize_megatron
from megatron.training import (get_model, build_train_valid_test_data_iterators,
                               forward_backward_pipelining_with_interleaving,
                               forward_backward_pipelining_without_interleaving,
                               forward_backward_no_pipelining)

initialize_megatron(extra_args_provider=None, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
model = get_model(model_provider)
args = get_args()

model_engine = deepspeed.init_inference(
    model[0],
    moe_experts=args.num_experts,
    replace_with_kernel_inject=True,
    dtype=torch.half if args.fp16 else None,
    moe=True,
)

model = model_engine.module

args.iteration = 0

train_data_iterator, valid_data_iterator, test_data_iterator \
    = build_train_valid_test_data_iterators(
        train_valid_test_datasets_provider)

forward_step(test_data_iterator, model, None)

And this leads to another error as follows:

Traceback (most recent call last):
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/../../pretrain_gpt.py", line 391, in <module>
    forward_step(test_data_iterator, model, None)
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/../../pretrain_gpt.py", line 211, in forward_step
    output_tensor, *other_losses = model(tokens, position_ids, attention_mask,
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/megatron/model/gpt_model.py", line 120, in forward
    lm_output, *moe_losses = self.language_model(
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/megatron/model/language_model.py", line 389, in forward
    encoder_output, *moe_losses = self.encoder(encoder_input,
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/megatron/model/transformer.py", line 769, in forward
    hidden_states, moe_losses = self._checkpointed_forward(hidden_states,
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/megatron/model/transformer.py", line 719, in _checkpointed_forward
    hidden_states, *local_moe_losses = mpu.checkpoint(
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 748, in checkpoint
    CheckpointFunction.apply(function, all_outputs, *args)
  File "/home/ubuntu/miniconda3/envs/tutel/lib/python3.9/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 582, in forward
    outputs = run_function(*inputs_cuda)
  File "/home/ubuntu/frameworks/Megatron-DeepSpeed/megatron/model/transformer.py", line 709, in custom_forward
    x_, moe_loss = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
ValueError: not enough values to unpack (expected 2, got 1)
(The second rank prints an identical traceback.)
[2022-08-07 09:25:24,352] [INFO] [launch.py:286:sigkill_handler] Killing subprocess 484730
[2022-08-07 09:25:24,360] [INFO] [launch.py:286:sigkill_handler] Killing subprocess 484731
[2022-08-07 09:25:24,360] [ERROR] [launch.py:292:sigkill_handler] ['/home/ubuntu/miniconda3/envs/tutel/bin/python3.9', '-u', '/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/../../pretrain_gpt.py', '--local_rank=1', '--override-lr-scheduler', '--adam-beta1', '0.9', '--adam-beta2', '0.95', '--tensor-model-parallel-size', '1', '--moe-expert-parallel-size', '2', '--num-experts', '2', '--moe-loss-coeff', '0.01', '--moe-train-capacity-factor', '1.0', '--moe-eval-capacity-factor', '1.0', '--moe-min-capacity', '4', '--init-method-std', '0.014', '--lr-decay-tokens', '300000000000', '--lr-warmup-tokens', '375000000', '--micro-batch-size', '8', '--exit-duration-in-mins', '5', '--global-batch-size', '256', '--num-layers', '12', '--hidden-size', '768', '--num-attention-heads', '12', '--seq-length', '2048', '--max-position-embeddings', '2048', '--train-tokens', '300000000000', '--train-iters', '0', '--lr', '1.2e-4', '--min-lr', '1.0e-6', '--lr-decay-style', 'cosine', '--split', '94,3,3', '--log-interval', '10', '--eval-interval', '100', '--eval-iters', '10', '--save-interval', '10', '--weight-decay', '0.1', '--clip-grad', '1.0', '--hysteresis', '2', '--num-workers', '0', '--fp16', '--load', '/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/output/checkpoint/gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true', '--save', '/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/output/checkpoint/gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true', '--tensorboard-queue-size', '1', '--log-timers-to-tensorboard', '--log-batch-size-to-tensorboard', '--log-validation-ppl-to-tensorboard', '--tensorboard-dir', '/home/ubuntu/frameworks/Megatron-DeepSpeed/examples/MoE/output/tensorboard/gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true_f09_2022.08.07-09.25.13', '--inference', '--checkpoint-activations', '--create-moe-param-group', '--vocab-file', '/home/ubuntu/frameworks/Megatron-DeepSpeed/gpt2-vocab.json', '--merge-file', '/home/ubuntu/frameworks/Megatron-DeepSpeed/gpt2-merges.txt', '--data-path', '/home/ubuntu/frameworks/Megatron-DeepSpeed/dataset/BookCorpusDataset/BookCorpusDataset_text_document', '--data-impl', 'mmap', '--deepspeed', '--deepspeed_config', 'ds_config_gpt_gpt-0.125B-lr-1.2e-4-minlr-1.0e-6-bs-256-gpus-2-mp-1-pp-1-ep-2-mlc-0.01-cap-1.0-drop-true.json', '--pipeline-model-parallel-size', '1', '--no-pipeline-parallel', '--deepspeed-activation-checkpointing'] exits with return code = 1
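The unpack failure is consistent with the replaced layer returning only hidden states where Megatron's checkpointed forward expects a (hidden_states, moe_loss) pair, as in the `x_, moe_loss = layer(...)` frame above. A tiny illustrative sketch of the mismatch (simplified stand-ins, not the actual modules):

def megatron_moe_layer(x):
    # Megatron's MoE transformer layer returns (hidden_states, moe_loss)
    return x, 0.0

def injected_layer(x):
    # if the replaced layer returns only hidden states (a 1-tuple here for brevity)
    return (x,)

x_, moe_loss = megatron_moe_layer(1.0)  # fine
x_, moe_loss = injected_layer(1.0)      # ValueError: not enough values to unpack (expected 2, got 1)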
awan-10 commented 2 years ago

@Gabriel4256 -- please look at this example: https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples/generate_text.sh

For inference, we don't use pretrain_gpt.py as the entry point. Please try the text-generation scenario above, which uses DeepSpeed inference. If you run into issues with it, please share them with us.

Gabriel4256 commented 2 years ago

@awan-10 Thanks for the comment. Unfortunately, I've already tried the example you shared and found it didn't work (https://github.com/microsoft/DeepSpeed/issues/2030#issuecomment-1193909540).

Gabriel4256 commented 2 years ago

I've also tried this on a machine with 8x V100 32GB GPUs, but it failed with almost the same error. Does the script only run on A100s?

jeffra commented 1 year ago

Closing to move the discussion to #2030; please re-open if the core issue here is not covered in the other issue.