Thanks for your excellent work!

In Section 4.1 of the paper, the attention chunk size c is set to 4K with 4 chunks, which results in a total context length of 32K. Why isn't the number of chunks 8? Is it a typo?

No, it is not a typo: we were using a chunk-parallel size of 2. Each GPU holds 16K tokens (4 chunks × 4096 tokens per chunk), and we split the 32K context across the two GPUs in a chunk-parallel group.
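For concreteness, here is a tiny sketch of that arithmetic (the constant names are illustrative, not from the repo):

```python
# Values from the discussion: chunk size 4096 (c = 4K), 4 chunks per GPU,
# and a chunk-parallel group of 2 GPUs.
CHUNK_SIZE = 4096            # tokens per attention chunk
CHUNKS_PER_GPU = 4           # physical chunks held by one GPU
CHUNK_PARALLEL_SIZE = 2      # GPUs in one chunk-parallel group

tokens_per_gpu = CHUNK_SIZE * CHUNKS_PER_GPU           # 16384 (16K)
total_context = tokens_per_gpu * CHUNK_PARALLEL_SIZE   # 32768 (32K)
logical_chunks = CHUNKS_PER_GPU * CHUNK_PARALLEL_SIZE  # 8

print(f"per-GPU tokens: {tokens_per_gpu}")
print(f"total context:  {total_context}")
print(f"logical chunks: {logical_chunks}")
```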
According to my understanding, the result from the first GPU needs to be passed to the CEMA layer and the Timestep Normalization layer of the second GPU, which means the second GPU can actually see the information from the first GPU. Therefore, the physical number of chunks per GPU is 4, but the logical number of chunks is 8. Is this correct?
Yes, it is correct.
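For illustration only (this is not the repository's actual code), here is a minimal sketch of how a recurrent state, such as the CEMA hidden state, could be forwarded from the first chunk-parallel rank to the second with `torch.distributed`; the function name and details are hypothetical, and the sketch assumes the process group has already been initialized and covers exactly the chunk-parallel group:

```python
import torch
import torch.distributed as dist

def pass_recurrent_state(local_state: torch.Tensor) -> torch.Tensor:
    """Receive the upstream rank's final state (zeros on rank 0) and
    forward this rank's state downstream, so each GPU can condition on
    the tokens held by the GPUs before it."""
    rank = dist.get_rank()
    world = dist.get_world_size()
    prev_state = torch.zeros_like(local_state)
    if rank > 0:
        dist.recv(prev_state, src=rank - 1)   # state from the upstream GPU
    # In a full implementation, local_state would first be combined with
    # prev_state before being sent on, since the dependency is sequential.
    if rank + 1 < world:
        dist.send(local_state, dst=rank + 1)  # forward to the downstream GPU
    return prev_state
```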