laekov / fastmoe

A fast MoE impl for PyTorch
https://fastmoe.ai
Apache License 2.0

Forward step return value is missing bal_loss #209

Open tisgotos opened 1 month ago

tisgotos commented 1 month ago

Problem encountered when running pretrain_gpt.py after applying the patch:

Traceback (most recent call last):
  File "pretrain_gpt.py", line 126, in <module>
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
  File "/workspace/Megatron-LM/megatron/training.py", line 157, in pretrain
    iteration = train(forward_step_func,
  File "/workspace/Megatron-LM/megatron/training.py", line 630, in train
    train_step(forward_step_func,
  File "/workspace/Megatron-LM/megatron/training.py", line 377, in train_step
    losses_reduced = forward_backward_func(
  File "/workspace/Megatron-LM/megatron/schedules.py", line 132, in forward_backward_no_pipelining
    output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
  File "/workspace/Megatron-LM/megatron/schedules.py", line 61, in forward_step
    output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
ValueError: not enough values to unpack (expected 3, got 2)

pretrain_gpt source code:

def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    return output_tensor, partial(loss_func, loss_mask)
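
For reference: the patched schedules.py in the traceback above unpacks three values at line 61, while the forward_step above returns only two. A minimal sketch of the return shape the scheduler expects; the bal_loss placeholder below is purely illustrative, since in practice fastmoe's patch_forward_step wrapper is what supplies the balance loss:

def forward_step(data_iterator, model):
    """Forward step that also hands the MoE balance loss to the scheduler."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    # Illustrative placeholder only: the real balance loss is gathered from
    # the MoE gates by fastmoe's patched forward step, not computed here.
    bal_loss = 0.0

    return output_tensor, partial(loss_func, loss_mask), bal_loss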
tisgotos commented 1 month ago

After modifying the pretrain_gpt code (applying the patch: from fmoe.megatron.patch import patch_loss_func_v2_5, patch_forward_step), training on a single GPU reports CUDA out of memory. Is there a known way to deal with this situation?

[before the start of training step] datetime: 2024-09-08 19:33:33

Traceback (most recent call last):
  File "pretrain_gpt.py", line 128, in <module>
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
  File "/workspace/Megatron-LM/megatron/training.py", line 157, in pretrain
    iteration = train(forward_step_func,
  File "/workspace/Megatron-LM/megatron/training.py", line 630, in train
    train_step(forward_step_func,
  File "/workspace/Megatron-LM/megatron/training.py", line 377, in train_step
    losses_reduced = forward_backward_func(
  File "/workspace/Megatron-LM/megatron/schedules.py", line 132, in forward_backward_no_pipelining
    output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
  File "/workspace/Megatron-LM/megatron/schedules.py", line 61, in forward_step
    output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
  File "pretrain_gpt.py", line 104, in forward_step
    return patched_forward_step(data_iterator, model)
  File "pretrain_gpt.py", line 104, in forward_step
    return patched_forward_step(data_iterator, model)
  File "pretrain_gpt.py", line 104, in forward_step
    return patched_forward_step(data_iterator, model)
  [Previous line repeated 14 more times]
  File "pretrain_gpt.py", line 100, in forward_step
    output_tensor = model(tokens, position_ids, attention_mask,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/fastmoe-1.1.0-py3.8-linux-x86_64.egg/fmoe/distributed.py", line 114, in forward
    return self.module(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/Megatron-LM/megatron/model/module.py", line 172, in forward
    outputs = self.module(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/Megatron-LM/megatron/model/gpt_model.py", line 96, in forward
    lm_output = self.language_model(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/Megatron-LM/megatron/model/language_model.py", line 351, in forward
    encoder_output = self.encoder(encoder_input,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/Megatron-LM/megatron/model/transformer.py", line 662, in forward
    hidden_states = self._checkpointed_forward(hidden_states,
  File "/workspace/Megatron-LM/megatron/model/transformer.py", line 616, in _checkpointed_forward
    hidden_states = mpu.checkpoint(
  File "/workspace/Megatron-LM/megatron/mpu/random.py", line 319, in checkpoint
    return CheckpointFunction.apply(function, *args)
  File "/workspace/Megatron-LM/megatron/mpu/random.py", line 262, in forward
    outputs = run_function(*args)
  File "/workspace/Megatron-LM/megatron/model/transformer.py", line 608, in custom_forward
    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/Megatron-LM/megatron/model/transformer.py", line 448, in forward
    self.self_attention(layernorm_output,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/Megatron-LM/megatron/model/transformer.py", line 308, in forward
    attention_probs = self.attention_dropout(attention_probs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/dropout.py", line 58, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 1252, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 23.69 GiB total capacity; 22.31 GiB already allocated; 22.81 MiB free; 22.46 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

laekov commented 1 month ago

OOM means the model or the intermediate results are too large. I would suggest switching to a smaller model.
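
Besides switching to a smaller model, the OOM message above also points at allocator fragmentation. A hedged sketch of that secondary mitigation, set before any CUDA work starts; the 128 MiB split size is only an example value and does not substitute for shrinking the model:

import os

# Suggested by the PyTorch OOM message itself; helps only when reserved
# memory greatly exceeds allocated memory. Must be set before CUDA is
# initialized (e.g. at the very top of pretrain_gpt.py).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"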