Describe the bug
I am training an LLM with DeepSpeed Pipeline Parallelism (ZeRO-0 or ZeRO-1). But I have a tricky issue:
Assume global_batch_size=4 on a single machine with 8 GPUs, and PP=8, so DP=1 and micro_batch_size=4.
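For reference, a minimal sketch of the batch-related part of the DeepSpeed config this setup implies (other fields omitted; with DP=1, train_batch_size = micro_batch_size * gradient_accumulation_steps * DP = 4 * 1 * 1 = 4):

```python
# Batch-related part of the DeepSpeed config implied by the setup above
# (other fields omitted; this is a sketch, not the exact config file).
ds_config = {
    "train_batch_size": 4,                # global_batch_size
    "train_micro_batch_size_per_gpu": 4,  # micro_batch_size
    "gradient_accumulation_steps": 1,     # 4 = 4 * 1 * 1 with DP=1
    "zero_optimization": {"stage": 1},    # ZeRO-1; the issue also occurs with ZeRO-0
}
```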
Further assume the first batch contains an input sequence with shape (4, 2262), whose corresponding hidden_states has shape (4, 2262, C), and the second batch contains an input sequence with shape (4, 2361), whose corresponding hidden_states has shape (4, 2361, C).
We also have the following stage partition:
But the following RuntimeError occurs:

```
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([4, 2262, 6144]) and output[0] has a shape of torch.Size([4, 2361, 6144]).
```
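For what it is worth, the same error message can be reproduced in plain PyTorch by passing backward() a gradient shaped like the previous batch, which is what makes me suspect a stale buffer on the receiving stage. A minimal sketch (not my training code):

```python
import torch

# Minimal repro of the same RuntimeError outside DeepSpeed: backward()
# receives a gradient shaped like the previous micro-batch while the
# output belongs to the current one.
x = torch.randn(4, 2361, 6144, requires_grad=True)
output = x * 2                           # current batch's forward output
stale_grad = torch.ones(4, 2262, 6144)   # gradient shaped like the previous batch

# Raises: RuntimeError: Mismatch in shape: grad_output[0] has a shape of
# torch.Size([4, 2262, 6144]) and output[0] has a shape of torch.Size([4, 2361, 6144]).
torch.autograd.backward([output], grad_tensors=[stale_grad])
```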
Actually, if I print the tensor shape in my customized `InternLMBlockPipeLayer(nn.Module)`:

```python
import deepspeed
import torch.distributed as dist
import torch.nn as nn


class InternLMBlockPipeLayer(nn.Module):
    def __init__(self, model: InternVLChatModel, layer_idx: int, gradient_checkpointing: bool = False):
        super().__init__()
        self.idx = layer_idx
        self.layer = model.language_model.model.layers[layer_idx]
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, ipt):
        hidden_states, attention_mask, position_ids, labels, random_id = ipt
        print('WARNING: ', hidden_states.shape, f'{self.__class__.__name__}.{self.idx}',
              f'cuda:{dist.get_rank()}', random_id, '\n')
        if self.gradient_checkpointing and self.training:
            output_attentions = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)
                return custom_forward

            # deepspeed.checkpointing.checkpoint automatically returns
            # outputs[0] if len(outputs) == 1
            outputs = deepspeed.checkpointing.checkpoint(
                create_custom_forward(self.layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
            )
            layer_outputs = [outputs]
        else:
            layer_outputs = self.layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        hidden_states = layer_outputs[0]
        return hidden_states, attention_mask, position_ids, labels, random_id
```
(`random_id` is defined in `collate_fn` and equals `len(input_ids)`; it is just a tag.)
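For context, a sketch of how these layers are assembled into the pipeline (the embedding/head layers are omitted and the exact arguments are simplified placeholders, not my full code):

```python
from deepspeed.pipe import LayerSpec, PipelineModule

# Simplified sketch of the pipeline assembly: each transformer block becomes
# one LayerSpec, and PipelineModule partitions the layers across the stages.
# Embedding/head layers are omitted for brevity.
block_specs = [
    LayerSpec(InternLMBlockPipeLayer, model, idx, gradient_checkpointing=True)
    for idx in range(len(model.language_model.model.layers))
]
pipe_module = PipelineModule(
    layers=block_specs,
    num_stages=8,                    # PP=8 on the single 8-GPU node
    partition_method="parameters",   # default partitioning by parameter count
)
```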
The first step (batch (4, 2262)) runs normally, but during the second step the log looks like this:
As you can see, the `print()` output is truncated across GPUs, and it seems the tensor with shape (4, 2361, 6144) goes missing when sent from stage 0 (GPU 0) to stage 1 (GPU 1).
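My current guess is that the pipeline engine caches the shapes of its communication buffers from the first micro-batch. A sketch of the workaround I am considering (assuming `PipelineEngine.reset_activation_shape()` is the right way to clear those cached buffers, which I have not verified; `engine` and `train_loader` are placeholders):

```python
# Sketch of the workaround under consideration: force the pipeline engine to
# re-negotiate its communication buffer shapes whenever the padded sequence
# length changes between steps. `engine` and `train_loader` are placeholders.
prev_seqlen = None
for batch in train_loader:
    seqlen = batch[0].shape[1]  # assumes input_ids is the first element
    if prev_seqlen is not None and seqlen != prev_seqlen:
        engine.reset_activation_shape()  # clears cached send/recv buffers
    prev_seqlen = seqlen
    loss = engine.train_batch(data_iter=iter([batch]))
```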
What should I do to fix this?
If anything else is required, please tell me. Thank you very much!
Expected behavior
Correct communication among stages (GPUs).
ds_report output
Please run `ds_report` to give us details about your setup.
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):
- OS: Ubuntu 22.04
- GPU count and types: 8x A100-80G
- Python: 3.9
- torch 2.1.2 (2.0.1 also affected) / DeepSpeed 0.13.5 / CUDA 12.2 / accelerate 0.31.0 / transformers 4.37.2