InternLM / xtuner

An efficient, flexible and full-featured toolkit for fine-tuning LLM (InternLM2, Llama3, Phi3, Qwen, Mistral, ...)
https://xtuner.readthedocs.io/zh-cn/latest/
Apache License 2.0

Question about Qwen1.5 32B-Chat training #793

Open wgs97 opened 2 months ago

wgs97 commented 2 months ago

The strange thing is that on 8*A100 (80G), no matter how I set max-seq, lowering it from 16000 all the way down to 200, I always hit OOM. Below are my command and config:

NPROC_PER_NODE=8 nohup xtuner train qwen1_5_32b_chat --deepspeed deepspeed_zero3 > instruct.out 2>&1 &

#######################################################################
#                           PART 1  Settings                          #
#######################################################################
pretrained_model_name_or_path = '/Qwen1.5-32B-Chat'
use_varlen_attn = False

data_files = ['test.json']
prompt_template = PROMPT_TEMPLATE.qwen_chat
max_length = 200
pack_to_max_length = False

sequence_parallel_size = 1

batch_size = 1  # per_device
accumulative_counts = 1
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 1e-5
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

save_steps = 500
save_total_limit = 1  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16))

#######################################################################
#                    PART 3  Dataset & Dataloader                     #
#######################################################################
sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler

train_dataset = dict(
    type=process_hf_dataset,
    use_varlen_attn=use_varlen_attn,
    dataset=dict(type=load_dataset, path='json', data_files=data_files),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(type=ThroughputHook)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    sampler_seed=dict(type=DistSamplerSeedHook),
)

env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

visualizer = None

log_level = 'INFO'

load_from = None

resume = True

randomness = dict(seed=None, deterministic=False)

log_processor = dict(by_epoch=False)

wgs97 commented 2 months ago

Also, on 8*A100 (80G), can training 32B-Chat with a 20K sequence length be supported?

HIT-cwh commented 2 months ago

Hi! For full fine-tuning of a 32B model, the model states alone occupy 32 * 16 / 8 = 64 GB of GPU memory on each card, and this portion does not shrink as the sequence length is reduced. (See the "How to choose a ZeRO strategy" documentation for the calculation.)

On top of that, activations and other tensors easily push memory usage into OOM.

If you want to train a 32B model on 8*A100 (80G), we recommend trying the LoRA or QLoRA algorithms. If you also want a 20k context, you can additionally try XTuner's sequence parallelism; see the documentation.
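For the record, a back-of-the-envelope version of that calculation (a hypothetical helper, not part of XTuner; it assumes mixed-precision AdamW under ZeRO-3, i.e. roughly 16 bytes of model state per parameter, sharded evenly across GPUs):

def zero3_model_state_gb(params_in_billions, num_gpus, bytes_per_param=16):
    # fp16 weights (2 B) + fp16 grads (2 B)
    # + fp32 master weights, Adam m and v (4 B + 4 B + 4 B) = 16 B/param,
    # sharded across all GPUs by ZeRO-3 (GB vs GiB rounding ignored).
    return params_in_billions * bytes_per_param / num_gpus

print(zero3_model_state_gb(32, 8))  # 64.0 GB per A100, before activations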

wgs97 commented 2 months ago

Hi! For full fine-tuning of a 32B model, the model states alone occupy 32 * 16 / 8 = 64 GB of GPU memory on each card, and this portion does not shrink as the sequence length is reduced. (See the "How to choose a ZeRO strategy" documentation for the calculation.)

On top of that, activations and other tensors easily push memory usage into OOM.

If you want to train a 32B model on 8*A100 (80G), we recommend trying the LoRA or QLoRA algorithms. If you also want a 20k context, you can additionally try XTuner's sequence parallelism; see the documentation.

Hi, thanks for the reply. I was previously able to train the 32B model with LLAMA-FACTORY under ZeRO-3 (though with shorter sequences), and now I want to use sequence parallelism as the length grows. I'm not sure what differences there are between XTuner and LLAMA-FACTORY for Qwen training besides sequence parallelism.
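For reference, a rough sketch of what enabling sequence parallelism might look like in the config posted above (sequence_parallel_size = 2 is only an illustrative assumption; the right value should be checked against XTuner's sequence-parallel docs):

sequence_parallel_size = 2  # assumed value: shard each long sequence across 2 GPUs
accumulative_counts = 1
accumulative_counts *= sequence_parallel_size  # as in the original config, to keep the global batch size

# The config's existing sampler selection then switches automatically:
# sampler = SequenceParallelSampler \
#     if sequence_parallel_size > 1 else DefaultSampler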

HIT-cwh commented 2 months ago

Could you share the detailed error message? That would help us determine whether the OOM happens during the forward or the backward pass.

Also, you could try ZeRO3-Offload (--deepspeed deepspeed_zero3_offload), though it may slow down training to some extent.
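Assuming the launch command from the original post, the only change would be the DeepSpeed config name, e.g.:

NPROC_PER_NODE=8 nohup xtuner train qwen1_5_32b_chat --deepspeed deepspeed_zero3_offload > instruct.out 2>&1 &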

wgs97 commented 2 months ago

Maybe I need to try Megatron...