hpcaitech / Open-Sora

Open-Sora: Democratizing Efficient Video Production for All
https://hpcaitech.github.io/Open-Sora/
Apache License 2.0

Fix SeqParallelMultiHeadCrossAttention for consistent results in distributed mode #492

Closed · Kipsora closed this 2 weeks ago

Kipsora commented 2 weeks ago

The current implementation of SeqParallelMultiHeadCrossAttention produces different results under different distributed settings.

For example, executing python scripts/inference.py configs/opensora-v1-2/inference/sample.py --num-frames 1 --resolution 720p --aspect-ratio 9:16 --prompt "a beautiful waterfall" --verbose 2 produces the following image:

[image: sample_0000]

However, running on two GPUs via torchrun --nproc_per_node 2 scripts/inference.py configs/opensora-v1-2/inference/sample.py --num-frames 1 --resolution 720p --aspect-ratio 9:16 --prompt "a beautiful waterfall" --verbose 2 produces:

[image: sample_0000]

While both results look marvelous, it would be better to keep them consistent across different distributed settings. They are inconsistent because the tensor Q is not reshaped correctly before the all_to_all operation across ranks.

If I understand correctly, Q has a shape of [1, (B, SUB_N), NUM_HEADS, HEAD_DIM] before all_to_all, after which we expect Q's shape to be [1, (B, SP, SUB_N), SUB_NUM_HEADS, HEAD_DIM] (where SP denotes the distributed world size). However, all_to_all simply concatenates along the gather dimension, so what we actually get is [1, (SP, B, SUB_N), SUB_NUM_HEADS, HEAD_DIM]. We can fix this either with the changes proposed in this pull request, or by performing a transpose as follows, after which the result is consistent regardless of the distributed setting:

--- a/opensora/models/layers/blocks.py
+++ b/opensora/models/layers/blocks.py
@@ -506,6 +506,9 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):

         # apply all_to_all to gather sequence and split attention heads
         q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
+        q = q.view(sp_size, B, SUB_N, self.num_heads // sp_size, self.head_dim)
+        q = q.transpose(0, 1)
+        q = q.contiguous()

         q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
         k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
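For illustration, here is a minimal single-process sketch of the chunk-and-concatenate semantics described above. The simulate_all_to_all helper and the tensor sizes are made up for this example and are not Open-Sora's actual distributed all_to_all wrapper; the sketch only shows why the gather dimension ends up ordered as (SP, B, SUB_N) and how the transpose restores the (B, SP, SUB_N) ordering before the flat view:

import torch

# Hypothetical sizes for illustration only.
B, SUB_N, NUM_HEADS, HEAD_DIM = 2, 3, 4, 8
SP = 2  # sequence-parallel world size

def simulate_all_to_all(shards, scatter_dim, gather_dim):
    # Each rank splits its local tensor into SP chunks along scatter_dim and sends
    # chunk r to rank r; the receiver concatenates the SP pieces along gather_dim
    # in sender-rank order, so the gather dim becomes (SP, original gather dim).
    # The return value is what rank 0 would hold afterwards.
    received_by_rank0 = [shard.chunk(SP, dim=scatter_dim)[0] for shard in shards]
    return torch.cat(received_by_rank0, dim=gather_dim)

# Each rank holds q of shape [1, B * SUB_N, NUM_HEADS, HEAD_DIM] for its sequence shard.
shards = [torch.randn(1, B * SUB_N, NUM_HEADS, HEAD_DIM) for _ in range(SP)]
q = simulate_all_to_all(shards, scatter_dim=2, gather_dim=1)
print(q.shape)  # [1, 12, 2, 8] == [1, SP * B * SUB_N, NUM_HEADS // SP, HEAD_DIM]

# The gather dim is ordered (SP, B, SUB_N), not (B, SP, SUB_N), so viewing it directly
# as [1, B * SP * SUB_N, ...] mixes tokens from different batch samples. The transpose
# restores (B, SP, SUB_N) ordering before the flat view, matching the diff above.
q = q.view(SP, B, SUB_N, NUM_HEADS // SP, HEAD_DIM).transpose(0, 1).contiguous()
q = q.view(1, -1, NUM_HEADS // SP, HEAD_DIM)
print(q.shape)  # [1, 12, 2, 8], now ordered (B, SP, SUB_N) along dim 1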
zhengzangw commented 2 weeks ago

Could @ver217 have a look?