Closed LeoAtlanto closed 2 months ago
Hi @LeoAtlanto
Thanks for reaching out, but sorry, I don't fully understand your plot. Could you please elaborate on what the issue is?
In the current implementation, each bwd step calculates and updates the partial dkv of the kv chunk it received. The partially updated dkv is then sent to the next GPU for the next update; basically, the dkv of each chunk is accumulated along the communication path.
We have unit tests in TE for functionality verification. Could you please explain the buggy case you are encountering? Thanks.
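To make the accumulation path concrete, here is a minimal, dependency-free sketch (my own illustration, not TE's actual code) of the idea described above: each GPU adds its local contribution for a kv chunk, then forwards the running `dkv` buffer to the next GPU, so after a full trip around the ring the buffer holds the sum over all GPUs. The `contrib` values and the visiting order are hypothetical stand-ins.

```python
# Sketch of ring accumulation of partial dkv (assumed semantics, not TE source).
cp_size = 4

# contrib[g][c]: hypothetical gradient contribution of GPU g to kv chunk c
contrib = [[float(g * 10 + c) for c in range(cp_size)] for g in range(cp_size)]

def ring_accumulate(chunk):
    """The chunk's dkv buffer visits every GPU once, accumulating as it goes."""
    dkv = 0.0
    for step in range(cp_size):
        gpu = (chunk + step) % cp_size  # order along the communication path
        dkv += contrib[gpu][chunk]      # partial update on this GPU, then send on
    return dkv

for c in range(cp_size):
    # after a full ring pass, dkv equals the total contribution from all GPUs
    assert ring_accumulate(c) == sum(contrib[g][c] for g in range(cp_size))
```

If any GPU's final-step branch adds or copies the wrong slice, this invariant breaks for the chunks that finish on that GPU, which is the kind of mismatch discussed below.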
Hi @xrennvidia
Thanks for your reply. I'll check the unit test first and then elaborate on the issue with some detailed debug info if there is indeed a bug.
Line 2223 in `transformer_engine/pytorch/attention.py` on the main branch (and other branches) is supposed to be `if rank == cp_size - 1`; otherwise the `dkv` accumulation on rank 0 and rank `cp_size - 1` would be wrong. Take `cp_size = 3` for example, where the right logic should be:
`if rank == 0 and i == 2: dkv.add_(dkv_)`
`if rank == 2 and i == 2: dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]); dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...])`
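To illustrate the proposed final-step branching, here is a small self-contained sketch (scalar stand-ins, not TE tensors; names `dkv`, `dkv_`, and the two-element lists modeling the two sequence halves are my assumptions) of the `cp_size = 3` case described above: rank 0 accumulates both halves on its last iteration, while the last rank accumulates the first half but overwrites the second.

```python
# Hedged sketch of the reported final-step dkv update for cp_size = 3.
# dkv / dkv_ are scalar stand-ins for the accumulated and freshly computed
# gradients; indices 0 and 1 model the two sequence halves of the causal
# context-parallel split.
cp_size = 3

def final_step_update(rank, i, dkv, dkv_):
    if rank == 0 and i == cp_size - 1:
        # rank 0: accumulate both halves
        dkv = [dkv[0] + dkv_[0], dkv[1] + dkv_[1]]
    elif rank == cp_size - 1 and i == cp_size - 1:
        # last rank: accumulate first half, copy (overwrite) second half
        dkv = [dkv[0] + dkv_[0], dkv_[1]]
    return dkv

# tiny checks with arbitrary values
assert final_step_update(0, 2, [1.0, 1.0], [0.5, 0.5]) == [1.5, 1.5]
assert final_step_update(2, 2, [1.0, 1.0], [0.5, 0.5]) == [1.5, 0.5]
```

The second assertion shows why the branch condition matters: if the last rank took the plain-accumulate path instead, its second half would end up as `1.5` rather than the copied `0.5`.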