xinwangChen opened this issue 7 months ago
The paper says 29.05 GFLOPs for XL/4, and this other GitHub issue says 28.8 GFLOPs: https://github.com/facebookresearch/DiT/issues/47
What do you get for XL/4?
I want to know how you calculate the FLOPs. I can't get the same FLOPs as in your paper with thop.
Try fvcore; it's a nice tool for calculating FLOPs.
You can also refer to the closed issue: https://github.com/facebookresearch/DiT/issues/14
For DiT-XL/2:

Latent size = 32 x 32 x 4
Patch size = 2
Sequence length = 4 * (32 // 2) ** 2 = 1024

For now, let's focus only on the pointwise feedforward network in a DiT block:

input_dim = 1152
hidden_dim = input_dim * 4 = 4608
output_dim = 1152

Ignoring the GELU activation and counting a multiply-add as 2 FLOPs, the FLOPs for one pointwise feedforward network are:

MM([1024 x 1152] x [1152 x 4608]) + MM([1024 x 4608] x [4608 x 1152]) = 21,743,271,936 ~ 21.74 GFLOPs

As we have 28 DiT blocks, the total FLOPs of the latent transformer are:

21,743,271,936 * 28 = 608,811,614,208 ~ 608.81 GFLOPs

FLOPs of DiT-XL/2 (256x256, just the pointwise feedforward networks in the DiT blocks) = 608.81 GFLOPs
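The feedforward tally above can be recomputed in plain Python. The dimensions are taken from this thread's calculation (sequence length 1024, width 1152, 28 blocks, 2 FLOPs per multiply-add), so this only checks the arithmetic, not the assumptions:

```python
# Recompute the pointwise-feedforward FLOPs from the thread's numbers.
# Convention here: one multiply-accumulate = 2 FLOPs; GELU ignored.

def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m x k) @ (k x n) matmul, counting mul + add."""
    return 2 * m * k * n

seq_len = 1024          # 4 * (32 // 2) ** 2, as computed above
d_model = 1152
d_hidden = 4 * d_model  # 4608

per_block = (matmul_flops(seq_len, d_model, d_hidden)
             + matmul_flops(seq_len, d_hidden, d_model))
total = 28 * per_block  # 28 DiT blocks

print(f"per block: {per_block:,} (~{per_block / 1e9:.2f} GFLOPs)")
print(f"28 blocks: {total:,} (~{total / 1e9:.2f} GFLOPs)")
```

This gives 21,743,271,936 per block and 608,811,614,208 for all 28 blocks.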
Reported in the paper, FLOPs of DiT-XL/2 (256x256) = 118.64 GFLOPs
From #14, FLOPs of DiT-XL/2 (256x256) = 118.64 x 2 = 237.28 GFLOPs
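The factor of 2 between those last two figures is purely a counting convention: fvcore and thop report multiply-accumulates (MACs), while counting the multiply and the add separately doubles the number. A one-liner makes the conversion explicit (the 118.64 figure is the paper's reported value):

```python
# Convert the paper's MAC-style count to a 2-FLOPs-per-MAC count,
# as done in issue #14.
reported_gmacs = 118.64        # paper's figure for DiT-XL/2 at 256x256
flops_2x = reported_gmacs * 2  # multiply and add counted separately
print(flops_2x)                # 237.28
```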
I think the FLOPs calculations in the paper are way off.
Please let me know if I have made any mistakes in my calculation.