NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

suspected dkv add/copy bug in context parallel #1054

Closed LeoAtlanto closed 2 months ago

LeoAtlanto commented 2 months ago

Line 2223 of `transformer_engine/pytorch/attention.py` on the main branch (and other branches) is presumably supposed to read `if rank == cp_size - 1`; otherwise the dkv accumulation on rank 0 and rank (cp_size - 1) would be wrong. Taking cp_size = 3 as an example, the correct logic should be: `if rank == 0 and i == 2: dkv.add_(dkv_)`; `if rank == 2 and i == 2: dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]); dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])`.
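To make the proposed branching concrete, here is a minimal sketch of the add/copy logic described above, using nested Python lists in place of the real dkv tensors (the two halves of the kv chunk are modeled as `dkv[0]` and `dkv[1]`, standing in for `dkv[:, :, 0, ...]` and `dkv[:, :, 1, ...]`). The function name `update_dkv` is hypothetical; this is not the actual TransformerEngine code.

```python
def update_dkv(rank, i, cp_size, dkv, dkv_):
    """Fold the freshly computed partial gradient dkv_ into the running
    buffer dkv on the last backward step, per the proposed fix."""
    if i == cp_size - 1:
        if rank == 0:
            # rank 0: both halves of the kv chunk are accumulated
            dkv[0] = [a + b for a, b in zip(dkv[0], dkv_[0])]
            dkv[1] = [a + b for a, b in zip(dkv[1], dkv_[1])]
        elif rank == cp_size - 1:
            # last rank: accumulate the first half, overwrite the second
            # (add_ on half 0, copy_ on half 1)
            dkv[0] = [a + b for a, b in zip(dkv[0], dkv_[0])]
            dkv[1] = list(dkv_[1])
    return dkv

# cp_size = 3 example from the comment above, on rank 2 at step i == 2
dkv  = [[1.0, 1.0], [2.0, 2.0]]
dkv_ = [[0.5, 0.5], [0.25, 0.25]]
print(update_dkv(rank=2, i=2, cp_size=3, dkv=dkv, dkv_=dkv_))
# first half added, second half copied: [[1.5, 1.5], [0.25, 0.25]]
```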

[image: plot illustrating the suspected dkv add/copy issue]
xrennvidia commented 2 months ago

Hi @LeoAtlanto

Thanks for reaching out, but sorry, I don't totally understand your plot. Could you please elaborate on what the issue is?

In the current implementation, each backward step calculates and updates the partial dkv of the kv chunk it received. The partially updated dkv is then sent to the next GPU for the next update; basically, the dkv of each chunk is accumulated along the communication path.

We have unit tests in TE for functionality verification. Could you please explain the buggy case you are encountering? Thanks.
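The ring-style accumulation described above can be modeled with a small toy simulation, using scalars in place of dkv tensors. All names here (`partial`, `buf`, `owner`) are hypothetical stand-ins, not TransformerEngine code: `partial[g][c]` is the gradient contribution GPU `g` computes for kv chunk `c`, and each chunk's accumulation buffer travels around the ring, collecting every GPU's contribution along the communication path.

```python
cp_size = 3
# hypothetical per-GPU partial gradients for each kv chunk
partial = [[10 * g + c for c in range(cp_size)] for g in range(cp_size)]

# one accumulation buffer per kv chunk, starting on its owner GPU
buf = {c: 0.0 for c in range(cp_size)}
owner = {c: c for c in range(cp_size)}  # which GPU currently holds chunk c's buffer

for step in range(cp_size):
    for c in range(cp_size):
        g = owner[c]
        buf[c] += partial[g][c]       # GPU g adds its partial dkv for chunk c
        owner[c] = (g + 1) % cp_size  # "send" the buffer to the next GPU

# after cp_size steps, every buffer has visited every GPU exactly once,
# so each chunk's dkv is fully accumulated
for c in range(cp_size):
    assert buf[c] == sum(partial[g][c] for g in range(cp_size))
print(buf)  # {0: 30.0, 1: 33.0, 2: 36.0}
```

In the real implementation the "send" is a point-to-point exchange between ranks rather than a dictionary update, but the accumulation pattern is the same.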

LeoAtlanto commented 2 months ago

> Thanks for reaching out, but sorry, I don't totally understand your plot. Could you please elaborate on what the issue is?
>
> In the current implementation, each backward step calculates and updates the partial dkv of the kv chunk it received. The partially updated dkv is then sent to the next GPU for the next update; basically, the dkv of each chunk is accumulated along the communication path.
>
> We have unit tests in TE for functionality verification. Could you please explain the buggy case you are encountering? Thanks.

Hi @xrennvidia, thanks for your reply. I'll check the unit tests first and then elaborate on the issue with some detailed debug info if there is indeed a bug.