Open wgs97 opened 2 months ago
另外,在8*A100(80G)的条件下,能否支持32B-Chat + 20K 长度的训练?,
你好! 全量微调 32B 模型,每张卡上模型状态部分就占据了 32 * 16 / 8 = 64G 的显存,这部分显存占用不会随着序列长度的降低而减少。(计算方法参考 如何选择 ZeRO 策略文档 )
再加上激活值等其他Tensor很容易导致 OOM 。
如果想使用 8*A100(80G)训练 32B 模型,建议尝试 Lora 或 QLora 算法。如果想训练 20k 上下文,还可以尝试使用 XTuner 的序列并行策略,参考文档
你好! 全量微调 32B 模型,每张卡上模型状态部分就占据了 32 * 16 / 8 = 64G 的显存,这部分显存占用不会随着序列长度的降低而减少。(计算方法参考 如何选择 ZeRO 策略文档 )
再加上激活值等其他Tensor很容易导致 OOM 。
如果想使用 8*A100(80G)训练 32B 模型,建议尝试 Lora 或 QLora 算法。如果想训练 20k 上下文,还可以尝试使用 XTuner 的序列并行策略,参考文档
你好,谢谢回复,之前试过32B在LLAMA-FACTORY在ZERO3是可以训练的(不过长度较短),现在希望在长度增大时使用序列并行。不确定Xtuner和LLAMA-FACTORY两个框架,在QWEN训练上除了序列并行之外有什么不同..
能不能提供一份详细的报错信息呢?方便我们定位OOM发生在forward还是backward的时候。
另外,也许可以尝试下ZeRO3-Offload (--deepspeed deepspeed_zero3_offload),但可能一定程度上拖慢训练速度。
也许我需要试试Megatron..
很奇怪的问题是,在8*A100(80G)上无论我如何设置max-seq,从16000降到200,始终都会OOM。 如下是我的命令和配置:
NPROC_PER_NODE=8 nohup xtuner train qwen1_5_32b_chat --deepspeed deepspeed_zero3 > instruct.out 2>&1 &
#######################################################################
PART 1 Settings
#######################################################################
pretrained_model_name_or_path = '/Qwen1.5-32B-Chat' use_varlen_attn = False
data_files = ['test.json'] prompt_template = PROMPT_TEMPLATE.qwen_chat max_length = 200 pack_to_max_length = False
sequence_parallel_size = 1
batch_size = 1 # per_device
accumulative_counts = 1 accumulative_counts *= sequence_parallel_size dataloader_num_workers = 0 max_epochs = 3 optim_type = AdamW lr = 1e-5 betas = (0.9, 0.999) weight_decay = 0 max_norm = 1 # grad clip warmup_ratio = 0.03
save_steps = 500 save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited)
#######################################################################
PART 2 Model & Tokenizer
####################################################################### tokenizer = dict( type=AutoTokenizer.from_pretrained, pretrained_model_name_or_path=pretrained_model_name_or_path, trust_remote_code=True, padding_side='right')
model = dict( type=SupervisedFinetune, use_varlen_attn=use_varlen_attn, llm=dict( type=AutoModelForCausalLM.from_pretrained, pretrained_model_name_or_path=pretrained_model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16))
#######################################################################
PART 3 Dataset & Dataloader
####################################################################### sampler = SequenceParallelSampler \ if sequence_parallel_size > 1 else DefaultSampler
train_dataset = dict( type=process_hf_dataset, use_varlen_attn=use_varlen_attn, dataset=dict(type=load_dataset, path='json', data_files=data_files), tokenizer=tokenizer, max_length=max_length, dataset_map_fn=None, template_map_fn=dict( type=template_map_fn_factory, template=prompt_template), remove_unused_columns=True, shuffle_before_pack=True, pack_to_max_length=pack_to_max_length)
train_dataloader = dict( batch_size=batch_size, num_workers=dataloader_num_workers, dataset=train_dataset, sampler=dict(type=sampler, shuffle=True), collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
#######################################################################
PART 4 Scheduler & Optimizer
#######################################################################
optim_wrapper = dict( type=AmpOptimWrapper, optimizer=dict( type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), accumulative_counts=accumulative_counts, loss_scale='dynamic', dtype='float16')
param_scheduler = [ dict( type=LinearLR, start_factor=1e-5, by_epoch=True, begin=0, end=warmup_ratio max_epochs, convert_to_iter_based=True), dict( type=CosineAnnealingLR, eta_min=0.0, by_epoch=True, begin=warmup_ratio max_epochs, end=max_epochs, convert_to_iter_based=True) ]
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
#######################################################################
PART 5 Runtime
####################################################################### custom_hooks = [ dict(type=DatasetInfoHook, tokenizer=tokenizer), dict(type=ThroughputHook) ]
if use_varlen_attn: custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
default_hooks = dict(
record the time of every iteration.
)
env_cfg = dict(
whether to enable cudnn benchmark
)
visualizer = None
log_level = 'INFO'
load_from = None
resume = True
randomness = dict(seed=None, deterministic=False)
log_processor = dict(by_epoch=False)