bigscience-workshop / Megatron-DeepSpeed

Ongoing research training transformer language models at scale, including: BERT & GPT-2

About convert deepspeed to deepspeed checkpoint #338

Open henan991201 opened 2 years ago

henan991201 commented 2 years ago

In training I used swiglu with TP=4, PP=2. I then used deepspeed_to_deepspeed.py to convert the checkpoint to TP=1, PP=1. When evaluating the converted checkpoint, its accuracy is inconsistent with the original checkpoint's. (When I do not use swiglu, the conversion produces a correct checkpoint.)
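(For context, swiglu is a GLU-family activation: the MLP up-projection output is split into two halves, one gating the other. A minimal sketch of the standard formulation, not a quote of this repo's implementation:)

import torch
import torch.nn.functional as F

def swiglu(y: torch.Tensor) -> torch.Tensor:
    # SwiGLU: split the up-projection output into a gate half and a
    # value half, then gate the value with SiLU(gate).
    gate, value = y.chunk(2, dim=-1)
    return F.silu(gate) * value

y = torch.randn(2, 16)        # output of an h -> 2*ffn projection
print(swiglu(y).shape)        # torch.Size([2, 8]): half the width

The halved output dimension is why GLU variants double the width of the up-projection.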

mayank31398 commented 2 years ago

This is expected behaviour @henan991201. When you reshard the model, the order of operations changes, and floating-point operations are not associative. Refer to https://github.com/pytorch/pytorch/issues/76232
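(A quick way to see the non-associativity: reduce the same fp16 values once globally and once shard-by-shard, as happens when partial results are combined after resharding. The example below is illustrative, not taken from the repo:)

import torch

torch.manual_seed(42)
x = torch.randn(4096, dtype=torch.float16)

# One global reduction vs. four per-shard reductions combined afterwards:
# identical values, different order of operations.
full = x.sum()
resharded = sum(chunk.sum() for chunk in x.chunk(4))
print(full.item(), resharded.item(), bool(full == resharded))  # typically False

Small drifts of this kind shift metrics slightly; they do not explain fully random outputs.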

mayank31398 commented 2 years ago

But I am not sure why this is happening only with swiglu.

henan991201 commented 2 years ago

In my experiment, if I use swiglu, I also find that I need to set TP=4 and PP=2 in the evaluation script in order to get correct results; otherwise the results are totally wrong (as if randomly selected). When I convert the checkpoint from TP=4, PP=2 to TP=1, PP=1, evaluation gives wrong results with swiglu but correct results with gelu. @mayank31398
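(One plausible mechanism for the swiglu-only failure, offered as a hypothesis and not verified against deepspeed_to_deepspeed.py: with a GLU activation, each tensor-parallel rank's slice of the h -> 2*ffn projection contains both a gate part and a value part, so a merge that simply concatenates rank slices along the output dimension interleaves gate and value columns. The merged TP=1 model then chunks that output in half at the wrong boundaries, which would produce exactly the "randomly selected" results described; gelu does no chunking, so column order does not matter. A toy illustration:)

import torch
import torch.nn.functional as F

def swiglu(y):
    gate, value = y.chunk(2, dim=-1)
    return F.silu(gate) * value

torch.manual_seed(0)
hidden, ffn, tp = 4, 8, 4
x = torch.randn(1, hidden)
w = torch.randn(hidden, 2 * ffn)   # full h -> 2*ffn GLU projection

# Each TP rank holds a slice of BOTH halves: [gate_i | value_i].
gates, values = w.chunk(2, dim=1)
shards = [torch.cat([g, v], dim=1)
          for g, v in zip(gates.chunk(tp, dim=1), values.chunk(tp, dim=1))]

# Naive merge: concatenate rank slices along the output dim. Columns come
# out as gate_0|value_0|gate_1|value_1|..., so chunk(2) in the merged
# model no longer separates gate from value.
naive = torch.cat(shards, dim=1)
print(torch.allclose(swiglu(x @ w), swiglu(x @ naive)))   # False

# Layout-aware merge: reassemble all gate slices, then all value slices.
merged_gates = torch.cat([s.chunk(2, dim=1)[0] for s in shards], dim=1)
merged_values = torch.cat([s.chunk(2, dim=1)[1] for s in shards], dim=1)
fixed = torch.cat([merged_gates, merged_values], dim=1)
print(torch.allclose(swiglu(x @ w), swiglu(x @ fixed)))   # True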

henan991201 commented 2 years ago

GPUS_PER_NODE=8
NNODES=1

TP_SIZE=4
PP_SIZE=2

MICRO_BATCH_SIZE=16
GLOBAL_BATCH_SIZE=256

NLAYERS=24
NHIDDEN=1024
NHEADS=16
SEQ_LEN=2048

SAVE_INTERVAL=500

TRAIN_SAMPLES=220_000_000
LR_DECAY_SAMPLES=200_000_000
LR_WARMUP_SAMPLES=183_105

OPTIMIZER_ARGS=" \
    --optimizer adam \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --adam-eps 1e-8 \
    --lr 3.0e-4 \
    --min-lr 1e-5 \
    --lr-decay-style cosine \
    --lr-decay-samples $LR_DECAY_SAMPLES \
    --lr-warmup-samples $LR_WARMUP_SAMPLES \
    --clip-grad 1.0 \
    --weight-decay 1e-1 \
    "

EXIT_OPTS=" \
    --exit-duration-in-mins 599000 \
    "

GPT_ARGS=" \
    --pp-partition-method type:transformer|embedding \
    --num-layers $NLAYERS \
    --hidden-size $NHIDDEN \
    --num-attention-heads $NHEADS \
    --seq-length $SEQ_LEN \
    --max-position-embeddings $SEQ_LEN \
    --micro-batch-size $MICRO_BATCH_SIZE \
    --global-batch-size $GLOBAL_BATCH_SIZE \
    --train-samples $TRAIN_SAMPLES \
    --tokenizer-type PretrainedFromHF \
    --tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
    --init-method-std 0.0048 \
    --fp16 \
    --seed 42 \
    --no-bias-gelu-fusion \
    --glu-activation swiglu \
    --position-embedding-type alibi \
    --checkpoint-activations \
    --abort-on-unmet-fused-kernel-constraints \
    $OPTIMIZER_ARGS \
    $EXIT_OPTS \
    "

OUTPUT_ARGS=" \
    --log-interval 1 \
    --save-interval $SAVE_INTERVAL \
    --eval-interval 500 \
    --eval-iters 1 \
    --tensorboard-dir $TENSORBOARD_PATH \
    --tensorboard-queue-size 5 \
    --log-timers-to-tensorboard \
    --log-batch-size-to-tensorboard \
    --log-validation-ppl-to-tensorboard \
    "

ZERO_STAGE=0
mkdir -p ds_config
config_json="./ds_config/ds_config.$SLURM_JOB_ID.json"

cat <<EOT > $config_json
{
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "train_batch_size": $GLOBAL_BATCH_SIZE,
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": $ZERO_STAGE
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 500,
    "hysteresis": 2,
    "min_loss_scale": 1,
    "initial_scale_power": 12
  },
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
EOT

DEEPSPEED_ARGS=" \
    --deepspeed \
    --deepspeed_config ${config_json} \
    --zero-stage ${ZERO_STAGE} \
    --deepspeed-activation-checkpointing \
    "

export LAUNCHER="python -u -m torch.distributed.launch \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    "

export CMD=" \
    `pwd`/pretrain_gpt.py \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    $GPT_ARGS \
    $OUTPUT_ARGS \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --data-path $DATA_PATH \
    --split 949,50,1 \
    --data-impl mmap \
    --distributed-backend nccl \
    $DEEPSPEED_ARGS \
    "