facebookresearch / DiT

Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

How do you calculate flops? #72

Open xinwangChen opened 7 months ago

xinwangChen commented 7 months ago

I want to know how you calculate the FLOPs. I can't get the same FLOPs as reported in your paper using thop.

pzpzpzp2 commented 7 months ago

The paper says 29.05 Gflops for XL/4, and this other github issue says 28.8 Gflops. https://github.com/facebookresearch/DiT/issues/47


What do you get for XL/4?

Rayjryang commented 6 months ago

> I want to know how you calculate the FLOPs. I can't get the same FLOPs as reported in your paper using thop.

Try fvcore; it's a nice tool to calculate FLOPs.


You can also refer to the closed issue: https://github.com/facebookresearch/DiT/issues/14
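
A minimal sketch of what that might look like for this repo's DiT-XL/2 (assuming `models.py` from this repo is importable and fvcore is installed; note that fvcore counts a fused multiply-add as a single FLOP, which is part of the 2x discussion in #14):

```python
import torch
from fvcore.nn import FlopCountAnalysis

from models import DiT_XL_2  # models.py from this repo

# DiT-XL/2 for 256x256 images: the VAE latent is 4x32x32, 1000 ImageNet classes.
model = DiT_XL_2(input_size=32)
model.eval()

x = torch.randn(1, 4, 32, 32)     # latent from the VAE encoder
t = torch.randint(0, 1000, (1,))  # diffusion timestep
y = torch.randint(0, 1000, (1,))  # class label

# fvcore counts one multiply-add as one FLOP, so the total should land
# near the paper's reported ~118.6 "Gflops" rather than twice that.
flops = FlopCountAnalysis(model, (x, t, y))
print(f"total GFLOPs (MAC convention): {flops.total() / 1e9:.2f}")
```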

SaudxInu commented 5 months ago

For DiT-XL/2 at 256x256:

Latent size = 32 x 32 x 4, patch size = 2, so sequence length = (32 // 2) ** 2 = 256.

For now let's focus only on the pointwise feedforward network in a DiT block: input_dim = 1152, hidden_dim = input_dim * 4 = 4608, output_dim = 1152.

FLOPs for one pointwise feedforward network, ignoring the GELU activation and counting a multiply-add as 2 FLOPs:

MM([256 x 1152] x [1152 x 4608]) + MM([256 x 4608] x [4608 x 1152]) = 2 * 2 * 256 * 1152 * 4608 = 5,435,817,984 ~ 5.44 GFLOPs

As we have 28 DiT blocks, the total FLOPs for the feedforward layers of the latent transformer is

5,435,817,984 * 28 = 152,202,903,552 ~ 152.2 GFLOPs

FLOPs for DiT-XL/2 at 256x256 with patch size 2 (pointwise feedforward networks only) ~ 152.2 GFLOPs
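As a sanity check, here is a small script reproducing the arithmetic above (pure Python, no model needed; the XL config numbers are hidden size 1152, mlp_ratio 4, 28 blocks, and the 2x factor comes from counting a multiply and an add as separate FLOPs):

```python
# Hand count of the feedforward (MLP) FLOPs in DiT-XL/2 at 256x256.
latent_hw = 32                            # 256x256 image -> 32x32 VAE latent
patch_size = 2
seq_len = (latent_hw // patch_size) ** 2  # 256 tokens
hidden = 1152
mlp_hidden = 4 * hidden                   # 4608
depth = 28

# Two matmuls per MLP: [T, 1152] @ [1152, 4608] and [T, 4608] @ [4608, 1152].
# Each has seq_len * hidden * mlp_hidden multiply-adds; GELU is ignored.
flops_per_mlp = 2 * seq_len * hidden * mlp_hidden * 2  # trailing *2: mul + add
total_mlp_flops = flops_per_mlp * depth

print(f"per-block MLP GFLOPs: {flops_per_mlp / 1e9:.2f}")    # ~5.44
print(f"28-block MLP GFLOPs:  {total_mlp_flops / 1e9:.2f}")  # ~152.2
```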

Reported in the paper, FLOPs for DiT-XL/2 at 256x256 with patch size 2 = 118.64 GFLOPs

From #14, FLOPs for DiT-XL/2 at 256x256 with patch size 2 = 118.64 x 2 = 237.28 GFLOPs

I think the FLOPs calculation in the paper is way off.

Please let me know if I have made any mistakes in my calculation.