Closed: melohux closed this issue 3 years ago
I reported all FLOPs data in the paper using the tool "thop".
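(For reference, a minimal sketch of the typical thop invocation; torchvision's `resnet50` here is only a stand-in for the actual DDF model, which would be built from this repo instead.)

```python
import torch
from thop import profile
from torchvision.models import resnet50

# Stand-in model for illustration; swap in the DDF network from this repo.
model = resnet50()
dummy = torch.randn(1, 3, 224, 224)

# thop traces the model once on the dummy input and returns op/parameter counts.
flops, params = profile(model, inputs=(dummy,))
print(f"{flops / 1e9:.2f}B FLOPs, {params / 1e6:.2f}M params")
```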
After seeing your comment, I checked it again by writing my own FLOPs-counting code, as follows:
```python
def ddf_flops(ch, fs, ks=3, r=0.2, downsample=False):
    ks2 = ks ** 2
    fs2 = fs ** 2
    rch = int(ch * r)
    flops = fs2 * ch + (ch * rch + rch * ch * ks2) + ch * ks2  # global pool + channel filters + filter norm
    if downsample:
        fs2 = (fs // 2) ** 2  # stride == 2
    flops += (ch * ks2 * fs2) + ks2 * fs2  # spatial filters + filter norm
    flops += (ch * ks2 * fs2) * 2  # combine and apply
    return flops


def block_flops(in_ch, mid_ch, out_ch, fs, downsample=False):
    fs2 = fs ** 2
    flops = in_ch * mid_ch * fs2
    flops += mid_ch * fs2  # BN
    flops += ddf_flops(mid_ch, fs, downsample=downsample)
    if downsample:
        fs2 = (fs // 2) ** 2  # stride == 2
    if in_ch != out_ch or downsample:
        flops += in_ch * out_ch * fs2
        flops += out_ch * fs2  # BN
    flops += mid_ch * fs2  # BN
    flops += mid_ch * out_ch * fs2
    flops += out_ch * fs2  # BN
    flops += out_ch * fs2  # x + res(x)
    return flops


total_flops = 0
# stem
total_flops += 3 * 64 * 7 * 7 * 112 * 112
total_flops += 64 * 112 * 112  # BN
# stem pooling
total_flops += 64 * 2 * 2 * 56 * 56
# layer1 (0)
total_flops += block_flops(64, 64, 256, 56)
# layer1 (1-2)
total_flops += block_flops(256, 64, 256, 56) * 2
# layer2 (0)
total_flops += block_flops(256, 128, 512, 56, True)
# layer2 (1-3)
total_flops += block_flops(512, 128, 512, 28) * 3
# layer3 (0)
total_flops += block_flops(512, 256, 1024, 28, True)
# layer3 (1-5)
total_flops += block_flops(1024, 256, 1024, 14) * 5
# layer4 (0)
total_flops += block_flops(1024, 512, 2048, 14, True)
# layer4 (1-2)
total_flops += block_flops(2048, 512, 2048, 7) * 2
total_flops += 7 * 7 * 2048  # pooling
total_flops += 2048 * 1000
print(total_flops)
```
I also get about 2.3B with my own code. Please double-check my FLOPs-counting code and let me know if anything is wrong. Thanks!
Thank you for your reply; I think your calculation is basically correct. You have missed some FLOPs for the activation functions, but yes, it is consistent with the 2.3B reported in the paper.
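(For completeness, a rough sketch of what that activation term could look like, assuming one FLOP per element and the usual bottleneck layout, i.e. a ReLU after the first 1x1 conv, one after the DDF/3x3 stage, and one after the residual addition. The `relu_flops` helper is just for illustration, not the authors' accounting.)

```python
def relu_flops(mid_ch, out_ch, fs, downsample=False):
    # One FLOP per activated element; three ReLUs per bottleneck block.
    fs2 = fs ** 2
    out_fs2 = (fs // 2) ** 2 if downsample else fs2  # stride == 2 halves the map
    return mid_ch * fs2 + mid_ch * out_fs2 + out_ch * out_fs2
```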
Thank you for your excellent work! I'm wondering how you obtained the 2.3B FLOPs claimed in your paper. I tested the FLOPs of ResNet-50 without your DDF module and it already takes 2.3B. Since there are some conv2d layers with channel interaction in your spatial branch, I think the total FLOPs of your DDFNet should be more than 2.3B, roughly 2.7B by my estimation.
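(As a rough cross-check, not the authors' accounting: pasted after the `ddf_flops` definition from the reply above, the snippet below tallies just the DDF-module contribution under the same layer layout as that script.)

```python
# Sum the DDF-module FLOPs per stage, mirroring the block counts used above.
ddf_only = (
    ddf_flops(64, 56) * 3                                             # layer1
    + ddf_flops(128, 56, downsample=True) + ddf_flops(128, 28) * 3    # layer2
    + ddf_flops(256, 28, downsample=True) + ddf_flops(256, 14) * 5    # layer3
    + ddf_flops(512, 14, downsample=True) + ddf_flops(512, 7) * 2     # layer4
)
print(f"DDF modules alone: {ddf_only / 1e6:.1f}M FLOPs")
```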