bigcode-project / Megatron-LM

Fixed MQA outputs not matching with HF model with non-flash case #71

Closed · mayank31398 closed this 1 year ago

mayank31398 commented 1 year ago

Flash Attention is working correctly: I see errors between the HF model's layers and the Megatron model's layers as low as 1e-3 to 1e-4 with fp16 precision. However, in the non-flash case, there are large errors due to incorrect shape handling during training.
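
(Editorial note: a minimal sketch, not from the thread, of how such a per-layer comparison can be measured; hf_out and megatron_out are hypothetical stand-ins for the same layer's output captured from the two models, e.g. with forward hooks.)

import torch

def max_abs_err(a: torch.Tensor, b: torch.Tensor) -> float:
    # Compare in fp32 so the metric itself adds no fp16 rounding noise.
    return (a.float() - b.float()).abs().max().item()

# Hypothetical stand-ins, shape [sq, b, h] as in the runs below.
hf_out = torch.randn(32, 2, 512, dtype=torch.float16)
megatron_out = hf_out + 1e-4 * torch.randn_like(hf_out)

print(max_abs_err(hf_out, megatron_out))  # on the order of 1e-4 to 1e-3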

jlamypoirier commented 1 year ago

The existing implementation looks fine to me; see the layer outputs below. Of course, the different shapes cause the intermediate values and the dropout mask to differ, but the end result is the same when dropout is disabled.

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc-per-node=1 Megatron-LM/pretrain_gpt.py \
--tokenizer-type=TokenizerFromFile \
--tokenizer-file=[...] \
--make-vocab-size-divisible-by=128 \
--num-workers=0 \
--valid-num-workers=0 \
--data-path=[...] \
--num-layers=1 \
--hidden-size=512 \
--num-attention-heads=4 \
--attention-head-type=multiquery \
--max-position-embeddings=32 \
--seq-length=32 \
--init-method-std=0.022 \
--DDP-impl=local  \
--initial-loss-scale=65536 \
--fp16 \
--train-iters=1 \
--micro-batch-size=2 \
--log-interval=1 \
--eval-iters=0 \
--lr=0.0002 \
[--use-flash-attn \]
--hidden-dropout=0 \
--attention-dropout=0 \
--lr-decay-style=constant

With flash (printing transformer layer output stats and every 997th value):

LAYER 1, name=None, shape=[32, 2, 512], dtype=torch.float16, device=cuda:0, stats=(8.166, 210.750), storage=140294837774848, storage size=65536, storage stats=(8.166, 210.750)
[-0.427978515625, 0.037841796875, 0.058929443359375, 0.26025390625, -0.2462158203125, -0.172119140625, 0.08416748046875, 0.2264404296875, -0.332763671875, 0.12109375, 0.107177734375, -0.071044921875, 0.189697265625, 0.178955078125, 0.239990234375, -0.1292724609375, -0.2047119140625, 0.28662109375, 0.0889892578125, -0.1063232421875, -0.115478515625, -0.16552734375, -0.145751953125, -0.10693359375, 0.388671875, -0.08074951171875, -0.14697265625, 0.183837890625, -0.1710205078125, -0.03802490234375, -0.11138916015625, 0.10986328125, 0.048828125]

Without flash:

LAYER 1, name=None, shape=[32, 2, 512], dtype=torch.float16, device=cuda:0, stats=(8.164, 210.750), storage=140037864947712, storage size=65536, storage stats=(8.164, 210.750)
[-0.427978515625, 0.037841796875, 0.05889892578125, 0.260009765625, -0.246337890625, -0.172119140625, 0.0838623046875, 0.226318359375, -0.332763671875, 0.12109375, 0.107666015625, -0.0709228515625, 0.18994140625, 0.178955078125, 0.2401123046875, -0.1290283203125, -0.2047119140625, 0.28662109375, 0.0887451171875, -0.1063232421875, -0.1156005859375, -0.1656494140625, -0.1458740234375, -0.1070556640625, 0.388671875, -0.080810546875, -0.14697265625, 0.1837158203125, -0.1712646484375, -0.0379638671875, -0.111328125, 0.10986328125, 0.04888916015625]
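
(Editorial note: a quick check on the first eight printed values confirms the two runs agree to within fp16 rounding.)

import torch

# First eight values printed by the flash and non-flash runs above.
flash = torch.tensor([-0.427978515625, 0.037841796875, 0.058929443359375,
                      0.26025390625, -0.2462158203125, -0.172119140625,
                      0.08416748046875, 0.2264404296875])
no_flash = torch.tensor([-0.427978515625, 0.037841796875, 0.05889892578125,
                         0.260009765625, -0.246337890625, -0.172119140625,
                         0.0838623046875, 0.226318359375])
print((flash - no_flash).abs().max().item())  # ~3e-4: fp16 rounding noise
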
jlamypoirier commented 1 year ago

The interest in the HF format (other than being simpler) is to reduce the number of copying transposes, but that needs a batch-first data format. I'm not sure there is much to be done with the sequence-first format (which is needed for sequence parallelism).
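
(For illustration, a minimal sketch of the copy that a sequence-first to batch-first conversion forces; shapes follow the runs above, and this is not code from the repo.)

import torch

sq, b, h = 32, 2, 512
x = torch.randn(sq, b, h)       # Megatron's sequence-first layout

# The transpose itself is free (a view), but any op that needs contiguous
# memory then triggers a real copy: the "copying transpose".
y = x.transpose(0, 1)           # [b, sq, h], non-contiguous view
print(y.is_contiguous())        # False
y = y.contiguous()              # materializes the copy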

jlamypoirier commented 1 year ago

Also, I think ALiBi would need to be adjusted.

mayank31398 commented 1 year ago

@jlamypoirier @RaymondLi0 I found the bug. The ALiBi tensor was being repeated incorrectly; the sequence of steps in this PR leads to the correct ALiBi tensor.

Earlier the tensor was [b, sq * np, sk] rather than [b, np * sq, sk].
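
(Editorial note: a minimal sketch of the layout bug, assuming the non-flash MQA path batches the queries of all heads along the sequence dimension so that the attention scores and the ALiBi bias must share a head-major [b, np * sq, sk] layout; the slope formula is the standard ALiBi one and all names are illustrative.)

import torch

b, n_heads, sq, sk = 2, 4, 32, 32

# Standard ALiBi: one bias row per head, slope * key position.
slopes = 2.0 ** (-8.0 * torch.arange(1, n_heads + 1) / n_heads)   # [np]
positions = torch.arange(sk, dtype=torch.float32).view(1, 1, sk)  # [1, 1, sk]
alibi = slopes.view(n_heads, 1, 1) * positions                    # [np, 1, sk]

# Correct: keep heads outermost, so row i belongs to head i // sq.
correct = alibi.expand(n_heads, sq, sk).repeat(b, 1, 1).view(b, n_heads * sq, sk)

# Buggy: tiling the whole [np, ...] block sq times interleaves the heads,
# yielding a [b, sq * np, sk] layout where row i belongs to head i % np.
buggy = alibi.repeat(b * sq, 1, 1).view(b, sq * n_heads, sk)

print(correct.shape == buggy.shape)   # True: same shape
print(torch.equal(correct, buggy))    # False: biases land on the wrong heads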

RaymondLi0 commented 1 year ago

Nice catch, thank you @mayank31398!