InternLM / xtuner

An efficient, flexible and full-featured toolkit for fine-tuning LLM (InternLM2, Llama3, Phi3, Qwen, Mistral, ...)
https://xtuner.readthedocs.io/zh-cn/latest/
Apache License 2.0
3.62k stars 296 forks source link

多卡训练报错 OOM:Mixtral 8x7b #521

Closed aiyinyuedejustin closed 4 months ago

aiyinyuedejustin commented 5 months ago

想要在8张40G 的A100 上微调Mixtral 8x7b. 发现OOM,之前在单卡A100 80G 上用相同的配置都没问题呀

使用的命令是: NPROC_PER_NODE=8 xtuner train mixtral_config.py --deepspeed deepspeed_zero2 > output.txt 2>&1

从sample output之后开始报错:

image image image

配置:

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
# from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Scalar settings referenced by the PART 2-5 sections below.

# Model
# Local checkpoint directory; used by both the tokenizer and the LLM specs.
pretrained_model_name_or_path = '/root/july-mixtral/Mixtral-8x7B-Instruct-v0.1'
# Forwarded to the model, the dataset and the collate_fn; also gates the
# extra VarlenAttnArgsToMessageHubHook appended in PART 5.
use_varlen_attn = False

# Data
# Local JSON file, loaded through load_dataset(path='json', ...) in PART 3.
# Previous value kept for reference: 'timdettmers/openassistant-guanaco'
data_path = '/root/july-mixtral/Mixtral-xtuner/paper_150_single_round_xtuner.json'
prompt_template = PROMPT_TEMPLATE.mixtral
# Passed to process_hf_dataset as max_length
# (presumably a per-sample token limit -- TODO confirm).
max_length = 12288
# Passed to process_hf_dataset as pack_to_max_length.
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 2  # per_device
# Forwarded to optim_wrapper as accumulative_counts.
accumulative_counts = 16
# Forwarded to train_dataloader as num_workers.
dataloader_num_workers = 0
max_epochs = 1  # was 3 (changed)
optim_type = AdamW
lr = 3e-5
betas = (0.9, 0.95)
weight_decay = 0
max_norm = 1  # grad clip
# The LinearLR warmup ends at warmup_ratio * max_epochs; cosine annealing
# starts from the same point.
warmup_ratio = 0.025

# Save
# CheckpointHook interval (the hook is configured with by_epoch=False).
save_steps = 500
save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
# EvaluateChatHook receives this as every_n_iters.
evaluation_freq = 50
# System prompt handed to EvaluateChatHook; empty in this config.
SYSTEM = ''
# NOTE(review): the single evaluation prompt below contains only whitespace
# and a newline -- the real prompt text was presumably removed before the
# config was posted; confirm before relying on the evaluation output.
evaluation_inputs = [ 

                     """ 
                     """]
#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
# Tokenizer spec: AutoTokenizer.from_pretrained on the checkpoint path from
# PART 1, padding on the right.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right',
)

# SupervisedFinetune wrapper around a 4-bit (NF4, double-quantised) base
# model with LoRA adapters on the listed projection modules.
model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        # bitsandbytes settings: 4-bit loading is on, 8-bit loading is off.
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        ),
    ),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        # Further modules such as lm_head could be added to this list.
        target_modules=[
            'q_proj',
            'k_proj',
            'v_proj',
            'o_proj',
            'w1',
            'w2',
            'w3',
        ],
        bias='none',
        task_type='CAUSAL_LM',
    ),
)

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
# Training dataset spec handed to process_hf_dataset.
train_dataset = dict(
    type=process_hf_dataset,
    # The raw data is the local JSON file named by data_path, exposed as the
    # 'train' split.
    # Alternative kept for reference:
    #   dataset=dict(type=load_dataset, path=data_path)
    dataset=dict(
        type=load_dataset,
        path='json',
        data_files=dict(train=data_path),
    ),
    tokenizer=tokenizer,
    max_length=max_length,
    # No dataset-specific map function is applied; only the prompt template
    # map function below is configured.
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory,
        template=prompt_template,
    ),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn,
)

# Dataloader spec: shuffling DefaultSampler plus the default collate
# function, with batch size and worker count taken from PART 1.
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(
        type=DefaultSampler,
        shuffle=True,
    ),
    collate_fn=dict(
        type=default_collate_fn,
        use_varlen_attn=use_varlen_attn,
    ),
)

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# Optimizer wrapper: AmpOptimWrapper in float16 with dynamic loss scaling,
# gradient-norm clipping and gradient accumulation, all driven by the
# PART 1 settings.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type,
        lr=lr,
        betas=betas,
        weight_decay=weight_decay,
    ),
    clip_grad=dict(
        max_norm=max_norm,
        error_if_nonfinite=False,
    ),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16',
)

# Learning-rate policy: linear warmup followed by cosine annealing.
# Both phases are declared in epochs and flagged convert_to_iter_based.
# Reference: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    # Phase 1: warmup over [0, warmup_ratio * max_epochs).
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True,
    ),
    # Phase 2: cosine decay to eta_min until max_epochs.
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True,
    ),
]

# Training loop spec (this config defines no val/test loop).
train_cfg = dict(
    type=TrainLoop,
    max_epochs=max_epochs,
)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Optional hooks: dataset info, plus a periodic chat evaluation that uses the
# PART 1 evaluation settings.
custom_hooks = [
    dict(
        type=DatasetInfoHook,
        tokenizer=tokenizer,
    ),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template,
    ),
]

# The varlen-attention hook is only registered when the feature is enabled.
if use_varlen_attn:
    custom_hooks.append(dict(type=VarlenAttnArgsToMessageHubHook))

# Default hooks.
default_hooks = dict(
    # Times every iteration.
    timer=dict(
        type=IterTimerHook,
    ),
    # Logs every 10 iterations, with metrics not grouped by epoch.
    logger=dict(
        type=LoggerHook,
        log_metric_by_epoch=False,
        interval=10,
    ),
    # Enables the parameter scheduler.
    param_scheduler=dict(
        type=ParamSchedulerHook,
    ),
    # Checkpoints every `save_steps` (by_epoch=False), keeping at most
    # `save_total_limit` of them; optimizer state is not saved.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit,
    ),
    # Sets the sampler seed in a distributed environment.
    sampler_seed=dict(
        type=DistSamplerSeedHook,
    ),
)

# Environment: cudnn benchmark off, 'fork' start method for multiprocessing,
# no OpenCV threads, NCCL backend for distributed training.
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(
        mp_start_method='fork',
        opencv_num_threads=0,
    ),
    dist_cfg=dict(
        backend='nccl',
    ),
)

# No visualizer is configured.
visualizer = None

# Logging verbosity.
log_level = 'INFO'

# Checkpoint to load before training (None: nothing is loaded).
load_from = None

# Whether to resume training from the loaded checkpoint.
resume = False

# No fixed seed and `deterministic` disabled.
randomness = dict(
    seed=None,
    deterministic=False,
)

# Log processor works per iteration rather than per epoch.
log_processor = dict(
    by_epoch=False,
)
aiyinyuedejustin commented 5 months ago

使用deepspeed_zero2_offload依然会OOM,我的batch size已经是1了,是否有别的可调整的地方呢?

LZHgrla commented 5 months ago

@aiyinyuedejustin 就是单纯的在40GB下会OOM,一种可能的方案是使用 LoRA 训练 或 全参数训练 + ZeRO3,可以实际试一下能不能跑起来

aiyinyuedejustin commented 5 months ago

> @aiyinyuedejustin 就是单纯的在40GB下会OOM,一种可能的方案是使用 LoRA 训练 或 全参数训练 + ZeRO3,可以实际试一下能不能跑起来

请问我应该这样设置config吗,需不需要 load_in_8bit 之类?


# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
# from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Scalar settings referenced by the PART 2-5 sections below.

# Model
# Local checkpoint directory; used by both the tokenizer and the LLM specs.
pretrained_model_name_or_path = '/mnt/llm_folder/Mixtral-87-onlysafetensors'
# Forwarded to the model, the dataset and the collate_fn; also gates the
# extra VarlenAttnArgsToMessageHubHook appended in PART 5.
use_varlen_attn = False

# Data
# Earlier choices kept for reference:
#   data_path = '/mnt/llm_folder/July-Mixtral-files/paper_single_round_1w5.json'
#   'timdettmers/openassistant-guanaco'
data_path = '/mnt/llm_folder/July-Mixtral-files/paper_full_single_5kround_xtuner.json'
prompt_template = PROMPT_TEMPLATE.mixtral
# Passed to process_hf_dataset as max_length
# (presumably a per-sample token limit -- TODO confirm).
max_length = 11000  # was 12288
# Passed to process_hf_dataset as pack_to_max_length.
pack_to_max_length = False  # was True

# Scheduler & Optimizer
batch_size = 1  # per_device
# Forwarded to optim_wrapper as accumulative_counts.
accumulative_counts = 25
# Forwarded to train_dataloader as num_workers.
dataloader_num_workers = 0
max_epochs = 2  # was 3 (changed)
optim_type = AdamW
lr =2e-4  # was 3e-5
betas = (0.9, 0.95)
weight_decay = 0
max_norm = 1  # grad clip
# The LinearLR warmup ends at warmup_ratio * max_epochs; cosine annealing
# starts from the same point.
warmup_ratio = 0.03

# Save (author's note: "33 88")

# CheckpointHook interval (the hook is configured with by_epoch=False).
save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
# EvaluateChatHook receives this as every_n_iters.
evaluation_freq = 500
# System prompt handed to EvaluateChatHook.
SYSTEM = "You are a professional machine learning conference reviewer who reviews a given paper and consider 4 criteria: ** importance and novelty **, ** potential reasons for acceptance **, ** potential reasons for rejection **, and ** suggestions for improvement **. You must output 4 criteria, with some subpoints for each criteria. The given paper is as follows:"
# NOTE(review): the only evaluation prompt is a single space -- the paper text
# was presumably removed before the config was posted; confirm before relying
# on the evaluation output.
evaluation_inputs = [ " " ]
#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
# Tokenizer spec: AutoTokenizer.from_pretrained on the checkpoint path from
# PART 1, padding on the right.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right',
)

# SupervisedFinetune wrapper around a float16 base model (no
# quantization_config is given) with LoRA adapters on the listed modules.
model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ),
    lora=dict(
        type=LoraConfig,
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        # Further modules such as lm_head could be added to this list;
        # 'w3' is left out in this variant.
        target_modules=[
            'q_proj',
            'k_proj',
            'v_proj',
            'o_proj',
            'w1',
            'w2',
        ],
        bias='none',
        task_type='CAUSAL_LM',
    ),
)

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
# Training dataset spec handed to process_hf_dataset.
train_dataset = dict(
    type=process_hf_dataset,
    # The raw data is the local JSON file named by data_path, exposed as the
    # 'train' split.
    # Alternative kept for reference:
    #   dataset=dict(type=load_dataset, path=data_path)
    dataset=dict(
        type=load_dataset,
        path='json',
        data_files=dict(train=data_path),
    ),
    tokenizer=tokenizer,
    max_length=max_length,
    # No dataset-specific map function is applied; only the prompt template
    # map function below is configured.
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory,
        template=prompt_template,
    ),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn,
)

# Dataloader spec: shuffling DefaultSampler plus the default collate
# function, with batch size and worker count taken from PART 1.
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(
        type=DefaultSampler,
        shuffle=True,
    ),
    collate_fn=dict(
        type=default_collate_fn,
        use_varlen_attn=use_varlen_attn,
    ),
)

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# Optimizer wrapper: AmpOptimWrapper in float16 with dynamic loss scaling,
# gradient-norm clipping and gradient accumulation, all driven by the
# PART 1 settings.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type,
        lr=lr,
        betas=betas,
        weight_decay=weight_decay,
    ),
    clip_grad=dict(
        max_norm=max_norm,
        error_if_nonfinite=False,
    ),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16',
)

# Learning-rate policy: linear warmup followed by cosine annealing.
# Both phases are declared in epochs and flagged convert_to_iter_based.
# Reference: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    # Phase 1: warmup over [0, warmup_ratio * max_epochs).
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True,
    ),
    # Phase 2: cosine decay to eta_min until max_epochs.
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True,
    ),
]

# Training loop spec (this config defines no val/test loop).
train_cfg = dict(
    type=TrainLoop,
    max_epochs=max_epochs,
)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Optional hooks: dataset info, plus a periodic chat evaluation that uses the
# PART 1 evaluation settings.
custom_hooks = [
    dict(
        type=DatasetInfoHook,
        tokenizer=tokenizer,
    ),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template,
    ),
]

# The varlen-attention hook is only registered when the feature is enabled.
if use_varlen_attn:
    custom_hooks.append(dict(type=VarlenAttnArgsToMessageHubHook))

# Default hooks.
default_hooks = dict(
    # Times every iteration.
    timer=dict(
        type=IterTimerHook,
    ),
    # Logs every 10 iterations, with metrics not grouped by epoch.
    logger=dict(
        type=LoggerHook,
        log_metric_by_epoch=False,
        interval=10,
    ),
    # Enables the parameter scheduler.
    param_scheduler=dict(
        type=ParamSchedulerHook,
    ),
    # Checkpoints every `save_steps` (by_epoch=False), keeping at most
    # `save_total_limit` of them; optimizer state is not saved.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit,
    ),
    # Sets the sampler seed in a distributed environment.
    sampler_seed=dict(
        type=DistSamplerSeedHook,
    ),
)

# Environment: cudnn benchmark off, 'fork' start method for multiprocessing,
# no OpenCV threads, NCCL backend for distributed training.
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(
        mp_start_method='fork',
        opencv_num_threads=0,
    ),
    dist_cfg=dict(
        backend='nccl',
    ),
)

# No visualizer is configured.
visualizer = None

# Logging verbosity.
log_level = 'INFO'

# Checkpoint to load before training (None: nothing is loaded).
load_from = None

# Whether to resume training from the loaded checkpoint.
resume = False

# No fixed seed and `deterministic` disabled.
randomness = dict(
    seed=None,
    deterministic=False,
)

# Log processor works per iteration rather than per epoch.
log_processor = dict(
    by_epoch=False,
)
LZHgrla commented 5 months ago

@aiyinyuedejustin 这就是LoRA训练,可以试一下(虽然不一定能跑起来)。 同时 max_length 太长了,40GB即使能跑也肯定跑不了这么长的长度

aiyinyuedejustin commented 5 months ago

> @aiyinyuedejustin 这就是LoRA训练,可以试一下(虽然不一定能跑起来)。 同时 max_length 太长了,40GB即使能跑也肯定跑不了这么长的长度

谢谢回复,我想问一下,如果max_len实在不能再降低了,如果怕oom,是否应该accumulative_counts 变小,lora r变小,lora alpha也变小?还有其他跟显存相关的参数不?

又或者我当前不爆显存,为了获得更好的sft效果,是不是主要增加lora r即可? lora r 跟alpha按经验是不是 lora r *2 =alpha嘞

LZHgrla commented 5 months ago

> > @aiyinyuedejustin 这就是LoRA训练,可以试一下(虽然不一定能跑起来)。 同时 max_length 太长了,40GB即使能跑也肯定跑不了这么长的长度
>
> 谢谢回复,我想问一下,如果max_len实在不能再降低了,如果怕oom,是否应该accumulative_counts 变小,lora r变小,lora alpha也变小?还有其他跟显存相关的参数不?
>
> 又或者我当前不爆显存,为了获得更好的sft效果,是不是主要增加lora r即可? lora r 跟alpha按经验是不是 lora r *2 =alpha嘞

lora 的参数确实会有影响。accumulative_counts 和其他部分应该几乎没什么影响。 lora的话,alpha 也可以固定,只调整 r