InternLM / xtuner

An efficient, flexible and full-featured toolkit for fine-tuning LLM (InternLM2, Llama3, Phi3, Qwen, Mistral, ...)
https://xtuner.readthedocs.io/zh-cn/latest/
Apache License 2.0
4.01k stars 316 forks source link

Mixtral 8x7B SFT 问题 #533

Open aiyinyuedejustin opened 8 months ago

aiyinyuedejustin commented 8 months ago

您好,正在尝试微调mixtral 8x7b,但是训练一段时间后loss不再下降,输出也有些问题

使用的config如下:

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
# from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings          increase lora rank!!!!!                #
#######################################################################
# Model
pretrained_model_name_or_path = '/mnt/llm_folder/Mixtral-87-onlysafetensors'
use_varlen_attn = False

# Data
data_path = '/mnt/llm_folder/July-Mixtral-files/paper_single_round_1w5.json' #'timdettmers/openassistant-guanaco'
# data_path = '/mnt/llm_folder/July-Mixtral-files/paper_full_single_5kround_xtuner.json'
prompt_template = PROMPT_TEMPLATE.mixtral
max_length = 12288
pack_to_max_length = False #True!!!!!

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 32 # change!!!
dataloader_num_workers = 0 #
max_epochs = 3 #3 --- 改
optim_type = AdamW 
lr =1e-4 # 3e-5 !!!!
betas = (0.9, 0.95)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.06

# Save 33 88 

save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = """You are a professional machine learning conference reviewer who reviews a given paper and considers 4 criteria: [Significance and novelty], [Potential reasons for acceptance], [Potential reasons for rejection], and [Suggestions for improvement]. For each criterion, provide random number of supporting points derived from the paper's content. Use the format: '<Title of supporting point>' followed by a detailed explanation. Ensure your points are specific, actionable, and directly related to the paper's content. The paper is given as follows:"""

evaluation_inputs = [
    """[TITLE]\nYaRN: Efficient Context Window Extension of Large Language Models\n\n[ABSTRACT]\n此处省略一堆。。。。。 """]
#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16),

    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=50,
        lora_dropout=0.1,
        target_modules=[
            'q_proj', 'k_proj', 'v_proj', 'o_proj', 'w1', 'w2' ,'w3'
        ], #这里还可以加lm_head之类的 'w3'
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
train_dataset = dict(
    type=process_hf_dataset,
    # dataset=dict(type=load_dataset, path=data_path),
     dataset=dict(
       type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer = False, # 不保存@@@!!!!!!!!
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed evrionment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)

训练使用8x 40G A100, Lora训练,显存没有爆。

在训练过程中,1000/5838 steps后loss就一直在1附近徘徊:

image

问题1:有什么参数可以更改吗?是不是我的lr太小了?或者batch size太小了?或许还有其他参数能更改吗

问题2: 如下图,mixtral的模板是不是错了呢,导致输出略奇怪,没有follow instruciton? 参考:https://github.com/InternLM/xtuner/blob/0b5708c49948355ae1406fb99ec6764417aee3fa/xtuner/utils/templates.py#L123 我的训练数据是自定义单轮,如下:

image

我发现训练开始前train example只是 "SOS" 开头,并没有加一个 [INST] , system结束后跟了一个[/INST]。 然后input也是用[INST]开头,然后[/INST]结尾, 然后直接就加上了output,最后以 EOS结尾, 如下:

image image

然而训练中eval的时候是 "SOS" +[INST] + system message+[/INST] +[INST]+input+[/INST]+eval_output 的形式:

image

会不会是模板错了, 不然训练开始前的train example 出现了两遍 [INST][/INST]这正常吗?且eval output的时候,理论上不应该重复input吧,应该直接输出我想要的output才对,但他还是重复了 special token和input,然后才开始output。

根据hugging facehttps://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1 https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/98

看起来需要如下格式?

<s> [INST] Instruction [/INST] Model answer</s> [INST] Follow-up instruction [/INST]

并且貌似我们的system和input是否应该合并起来作为一个总的"instruction"?而不是分开?比如应该 SOS+ [INST] system+input+[/INST] +output +EOS ?

aiyinyuedejustin commented 8 months ago

提供更多信息:

[check_data.txt](https://github.com/InternLM/xtuner/files/14807575/check_data.txt)

上面的文件是运行xtuner check-custom-data产生的。

可以看到,在dataset after adding template后,'input'键中出现了两次[INST][/INST],

格式为[INST]+我想要的system+[/INST]+[INST]+我想要的input+[/INST]。

然后就是'output'键, 然后又来了个'system'键,如下:

image image

会不会是这里的template导致的问题呢?还是说这就是xtuner要求的格式

LZHgrla commented 8 months ago

@aiyinyuedejustin 问题1:我们这边没有相关的经验,但调整lr、batchsize是一个好方法;另外,loss与实际性能也并不是强相关的 问题2:mixtral本身没有system字段的支持,xtuner通过复用user字段来支持system的功能。如果你期望完全对齐,请将system文本放置在user内,并在数据集中删除system字段。eval output时就是会“打印”出input,而非模型生成的。

另外,训练开头打印数据那个地方,是<s> You xxx [/INST],而非<s> [INST] You xxx [/INST] 有点奇怪,请问这个能稳定复现吗?

aiyinyuedejustin commented 8 months ago

@aiyinyuedejustin 问题1:我们这边没有相关的经验,但调整lr、batchsize是一个好方法;另外,loss与实际性能也并不是强相关的 问题2:mixtral本身没有system字段的支持,xtuner通过复用user字段来支持system的功能。如果你期望完全对齐,请将system文本放置在user内,并在数据集中删除system字段。eval output时就是会“打印”出input,而非模型生成的。

另外,训练开头打印数据那个地方,是<s> You xxx [/INST],而非<s> [INST] You xxx [/INST] 有点奇怪,请问这个能稳定复现吗?

问题1: 谢谢

问题2: 能稳定复现,比如下面是我刚刚调整了config之后,全量微调的:

image
LZHgrla commented 8 months ago

@aiyinyuedejustin https://github.com/InternLM/xtuner/blob/0b5708c49948355ae1406fb99ec6764417aee3fa/xtuner/dataset/huggingface.py#L182-L184 单卡训练,在这几行的后面打个断点,检查一下 input 中的这个初始的[INST]是否正常吧