Open xju2 opened 2 months ago
Facing the same issue, please let me know if you have found a fix! Thanks!
In Megatron-DeepSpeed/megatron/model/bert_model.py there is a line:
extended_attention_mask = bert_extended_attention_mask(attention_mask)
where bert_extended_attention_mask is defined as:
def bert_extended_attention_mask(attention_mask):
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)
    # Convert attention mask to binary:
    extended_attention_mask = (extended_attention_mask < 0.5)
    return extended_attention_mask
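To see the shape expansion concretely, here is a NumPy port of the helper above (NumPy is used purely for illustration; the original operates on torch tensors, where unsqueeze corresponds to inserting an axis):

```python
import numpy as np

def bert_extended_attention_mask(attention_mask):
    """NumPy port of the Megatron-DeepSpeed helper, for illustration only."""
    # [b, 1, s]
    attention_mask_b1s = attention_mask[:, np.newaxis, :]
    # [b, s, 1]
    attention_mask_bs1 = attention_mask[:, :, np.newaxis]
    # [b, s, s]: entry (i, j) is 1 only if tokens i and j are both real
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss[:, np.newaxis, :, :]
    # Convert to binary: True marks positions to be masked OUT
    return extended_attention_mask < 0.5

# Batch of 1, sequence length 4, last token is padding.
mask = np.array([[1, 1, 1, 0]])
extended = bert_extended_attention_mask(mask)
print(extended.shape)  # (1, 1, 4, 4)
```

Running this shows the [b, s] padding mask becoming a [b, 1, s, s] pairwise mask: every row for a real token masks only the padded keys, while the row for the padded token is masked entirely.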
The attention_mask is thus extended from [b, s] to [b, 1, s, s]. Is this the cause of the problem? If so, how can I fix it?
Used the docker image nvcr.io/nvidia/pytorch:23.12-py3; Megatron-LM commit ID: c4d12e2
Using the Megatron-LM 23.08 branch with the docker image nvcr.io/nvidia/pytorch:23.08-py3 avoids this problem.
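If pinning the older branch is not an option, one hypothetical direction (an untested sketch, not a verified fix; collapse_to_padding_mask is a name invented here) is to recover the [B, 1, 1, s] padding mask that get_cu_seqlens expects from the diagonal of the [B, 1, s, s] mask:

```python
import numpy as np

def collapse_to_padding_mask(extended_mask):
    """Recover a [b, 1, 1, s] padding mask from a [b, 1, s, s] self-attention
    mask (True = masked out). Diagonal entry (i, i) is True exactly when
    token i is padding, so the diagonal alone carries the padding info."""
    diag = np.diagonal(extended_mask[:, 0, :, :], axis1=1, axis2=2)  # [b, s]
    return diag[:, np.newaxis, np.newaxis, :]                        # [b, 1, 1, s]

# Build a [1, 1, 4, 4] mask for one sample whose last two tokens are padding.
token_mask = np.array([[1, 1, 0, 0]])
pairwise = token_mask[:, None, :] * token_mask[:, :, None]  # [b, s, s]
extended = (pairwise[:, None, :, :] < 0.5)                  # True = masked
collapsed = collapse_to_padding_mask(extended)
print(collapsed.shape)    # (1, 1, 1, 4)
print(collapsed[0, 0, 0])  # padding mask: False, False, True, True
```

The diagonal is used rather than the first row because the first row is only reliable when token 0 is real; the diagonal makes no such assumption.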
Describe the bug
Running the Pretraining BERT example hit two issues:

1. One issue can be worked around by adding --attention-softmax-in-fp32 to the model arguments. This applies to the Pretraining GPT script pretrain_gpt.sh too.
2. The attention mask has shape [B, 1, max_seqlen, max_seqlen]; however, the function get_cu_seqlens expects its shape to be [B, 1, 1, max_seqlen]. The training crashes. See the log below.

To Reproduce
Run the example ./examples/pretrain_bert.sh in the docker image nvcr.io/nvidia/pytorch:24.02-py3 with the main branch of Megatron-LM. The issue was found in the core_r0.6.0 branch too.

Expected behavior
The example runs out of the box.
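For context on the expected shape, a cu_seqlens-style helper typically reduces a [B, 1, 1, s] padding mask to per-sample token counts followed by a prefix sum; below is a simplified sketch of that pattern (the actual get_cu_seqlens in Megatron-LM may differ in its details):

```python
import numpy as np

def cu_seqlens_from_padding_mask(padding_mask):
    """Simplified sketch: turn a [B, 1, 1, s] padding mask (True = padded)
    into cumulative sequence lengths [0, len_0, len_0 + len_1, ...], the
    format variable-length attention kernels consume. Illustrative only."""
    # Count the real (unmasked) tokens in each sample: [B]
    seqlens = (~padding_mask[:, 0, 0, :]).sum(axis=-1)
    # Prefix-sum with a leading zero: [B + 1]
    return np.concatenate([[0], np.cumsum(seqlens)])

# Two samples of lengths 3 and 2 (True marks padding).
mask = np.array([[[[False, False, False, True]]],
                 [[[False, False, True, True]]]])
print(cu_seqlens_from_padding_mask(mask))  # [0 3 5]
```

A [B, 1, max_seqlen, max_seqlen] input would break the indexing above in the same way the report describes, since the per-sample length can no longer be read off a single mask row.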
Stack trace/logs
Environment (please complete the following information):
- Docker image: nvcr.io/nvidia/pytorch:24.02-py3
- Megatron-LM commit ID: ccfeda4
- PyTorch version: 2.3.0a0+ebedce2
- CUDA version: 12.3
- NCCL version: 2.20.3
Proposed fix N/A
Additional context N/A