epfLLM / Megatron-LLM

distributed trainer for LLMs

Add Mistral Model #88

Closed xingyaoww closed 7 months ago

xingyaoww commented 8 months ago

https://github.com/epfLLM/Megatron-LLM/issues/76#issue-1947436642

A preliminary Mistral implementation that relies on FlashAttention for windowed (sliding-window) attention.
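
For reference, a minimal sketch of how flash-attn (>= 2.3) exposes sliding-window attention; this is illustrative only and not the PR's code. The 4096-token window and head counts follow the Mistral-7B config, but the shapes and values here are made up:

# Minimal sketch (not this PR's code): sliding-window attention via flash-attn >= 2.3.
# Mistral-7B uses a 4096-token sliding window; flash_attn_func's window_size
# argument applies it without materializing a banded attention mask.
import torch
from flash_attn import flash_attn_func

batch, seqlen, n_heads, head_dim = 1, 8192, 32, 128
q = torch.randn(batch, seqlen, n_heads, head_dim, dtype=torch.bfloat16, device="cuda")
# For simplicity k/v use the same number of heads here; Mistral actually uses
# grouped-query attention (8 KV heads), which flash_attn_func also supports.
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(4096, 0),  # each query attends only to (roughly) the previous 4096 keys
)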

Script for model conversion:

#!/bin/bash

# download from huggingface: https://huggingface.co/mistralai/Mistral-7B-v0.1
RAW_MODEL_WEIGHT_DIR=/models/Mistral-7B-v0.1/
OUTPUT_DIR=data/models/raw/Mistral-7b-megatron

python Megatron-LLM/weights_conversion/hf_to_megatron.py mistral \
    --size=7 \
    --out=$OUTPUT_DIR \
    --model-path=$RAW_MODEL_WEIGHT_DIR

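The raw weights referenced by RAW_MODEL_WEIGHT_DIR can be fetched with huggingface_hub, for example (a sketch, not part of the PR; the local_dir path is the one used above):

# Hypothetical helper for the "download from huggingface" step in the script above:
# snapshot_download fetches the full Mistral-7B-v0.1 repo into the given directory.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="mistralai/Mistral-7B-v0.1",
    local_dir="/models/Mistral-7B-v0.1",
)
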
Script for verify_correctness:

# arguments required by `torchrun`
DISTRIBUTED_ARGS="--nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 8000"
MISTRAL_ARGS="--use_rms_norm --glu_activation swiglu --no_tie_embed_logits --no_new_tokens --layernorm_epsilon 1e-5 --use_flash_attn --bf16 --seq_length 512"
COMMON_ARGS="--hidden_dropout 0.0 --attention_dropout 0.0 --no_bias_gelu_fusion --no_bias_dropout_fusion"

HF_MODEL_DIR=/models/Mistral-7B-v0.1
MODEL_CKPT=data/models/raw/Mistral-7b-megatron
TOKENIZER_PATH=data/models/raw/Mistral-7b-megatron/tokenizer.model
DATA_PATH=data/megatron_format/starcoder_example/data_text_document

torchrun $DISTRIBUTED_ARGS Megatron-LLM/verify_correctness.py \
    --model_name=mistral \
    --model_size=7 \
    --load=$MODEL_CKPT \
    --data_path=$DATA_PATH \
    --tokenizer_type=SentencePieceTokenizer \
    --vocab_file=$TOKENIZER_PATH \
    --huggingface_cache=$HF_MODEL_DIR \
    --huggingface_device=cuda:1 \
    $COMMON_ARGS $MISTRAL_ARGS

The starcoder_example dataset can be prepared following the quickstart guide.

Results (the average loss diff is below 0.1 for bf16):

Iteration 0...
Max absoulute error in the logits: max=0.928502, avg=0.024152
Abs loss error: 0.000518 Our loss: 0.919, theirs: 0.919
Iteration 1...
Max absoulute error in the logits: max=1.135055, avg=0.017372
Abs loss error: 0.000350 Our loss: 1.351, theirs: 1.352
Iteration 2...
Max absoulute error in the logits: max=0.439639, avg=0.016328
Abs loss error: 0.001454 Our loss: 0.596, theirs: 0.598
Iteration 3...
Max absoulute error in the logits: max=1.323759, avg=0.016411
Abs loss error: 0.004381 Our loss: 1.550, theirs: 1.554
Iteration 4...
Max absoulute error in the logits: max=0.709203, avg=0.016201
Abs loss error: 0.002717 Our loss: 1.754, theirs: 1.757
Iteration 5...
Max absoulute error in the logits: max=0.776079, avg=0.017546
Abs loss error: 0.001265 Our loss: 1.462, theirs: 1.463
Iteration 6...
Max absoulute error in the logits: max=0.925026, avg=0.015901
Abs loss error: 0.002268 Our loss: 1.208, theirs: 1.211
Iteration 7...
Max absoulute error in the logits: max=1.482982, avg=0.018049
Abs loss error: 0.001490 Our loss: 1.300, theirs: 1.302
Iteration 8...
Max absoulute error in the logits: max=2.211723, avg=0.019359
Abs loss error: 0.001245 Our loss: 0.623, theirs: 0.625
Iteration 9...
Max absoulute error in the logits: max=2.620514, avg=0.017941
Abs loss error: 0.001800 Our loss: 0.899, theirs: 0.897
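
For context, the per-iteration numbers above come from comparing the Megatron forward pass against the Hugging Face reference model on the same batch; conceptually they are computed roughly as in the sketch below (an illustration, not the actual verify_correctness.py code):

# Rough sketch of the comparison behind the log above (not verify_correctness.py itself).
# our_logits / hf_logits: [batch, seq, vocab] outputs of the Megatron and HF models
# on the same batch; our_loss / hf_loss: the corresponding language-modeling losses.
import torch

def compare(our_logits, hf_logits, our_loss, hf_loss):
    abs_err = (our_logits.float() - hf_logits.float()).abs()
    print(f"Max absolute error in the logits: max={abs_err.max():.6f}, avg={abs_err.mean():.6f}")
    print(f"Abs loss error: {abs(our_loss - hf_loss):.6f} "
          f"Our loss: {our_loss:.3f}, theirs: {hf_loss:.3f}")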