yanzq95 / SGNet

SGNet: Structure Guided Network via Gradient-Frequency Awareness for Depth Map Super-Resolution (AAAI-2024)
Apache License 2.0

FLOPs calculation #12

Closed makhovds closed 11 hours ago

makhovds commented 3 weeks ago

Good day, and thank you for sharing the code and the full training and testing pipeline for your network. I have a question about the FLOPs calculation in Table 5 of your paper. I use the following simple snippet:

import torch
from thop import profile, clever_format

from models.SGNet import SGNet

scale = 8
H = 480
W = 640
num_feats = 40

model = SGNet(num_feats=num_feats, kernel_size=3, scale=scale)
model.eval()

# Inputs to the model
image = torch.randn(1, 3, H, W)
depth = torch.randn(1, 1, H // scale, W // scale)

macs, params = profile(model, inputs=((image, depth),))
macs, params = clever_format([macs, params], "%.3f")
print('=' * 100)
print('=' * 30, ' SGNet ', '=' * 30)
print('=' * 100)
print(macs)

With this snippet I get 7.22 TFLOPs instead of the reported 4623.9 GFLOPs. Could you please clarify how you estimated the FLOPs for your network?

yanzq95 commented 3 weeks ago

The inference time in the table is measured on the NYU dataset, while the FLOPs are calculated on the RGB-D-D dataset at a resolution of 512×384. You can use the following code to calculate it:

import torch
from thop import profile

import models.SGNet as Net

model = Net.SGNet(num_feats=40, kernel_size=3, scale=8)
model.eval()

rgb = torch.randn(1, 3, 384, 512)
d = torch.randn(1, 1, 48, 64)

flops, params = profile(model, inputs=((rgb, d),))

print('+++++++++++++++++++++++++++++++')
print('params: %.2f M, flops: %.2f G' % (params / 1e6, flops / 1e9))
print('+++++++++++++++++++++++++++++++')
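For what it's worth, the gap between the two numbers is almost exactly the ratio of the two input resolutions, which is consistent with a fully convolutional model whose FLOPs scale roughly linearly with pixel count. A quick sanity check with the numbers from this thread (the linear-scaling assumption is mine, not from the paper):

```python
# Assumption: for a fully convolutional network, FLOPs scale roughly
# linearly with the number of input pixels.
nyu_pixels = 480 * 640      # resolution used in the question (NYU)
rgbdd_pixels = 384 * 512    # resolution used for Table 5 (RGB-D-D)

ratio = nyu_pixels / rgbdd_pixels
print(f"pixel ratio: {ratio:.4f}")  # 1.5625

reported_gflops = 4623.9                            # Table 5 value
expected_at_nyu = reported_gflops * ratio / 1000    # convert to TFLOPs
print(f"expected at 480x640: {expected_at_nyu:.2f} TFLOPs")  # ~7.22
```

So the 7.22 TFLOPs from the first snippet matches the paper's 4623.9 GFLOPs once the resolution difference is accounted for. Note also that, as far as I know, thop's `profile` counts multiply-accumulate operations, which some papers quote directly as FLOPs and others double; the convention here appears to be the former.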