ZhangGongjie / SAM-DETR

[CVPR'2022] SAM-DETR & SAM-DETR++: Official PyTorch Implementation
MIT License
298 stars 50 forks source link

Can you please provide the GFLOPs calculation code #8

Closed cxq1 closed 2 years ago

cxq1 commented 2 years ago

Thank you for your work! Can you please provide the GFLOPs calculation code?

ZhangGongjie commented 2 years ago
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
from fvcore.nn import flop_count_str

# Similar codes in main.py are omitted here for simplicity

    with torch.no_grad():
        model.eval()
        i = 0
        total = 0
        for samples, targets in data_loader_val:
            i += 1
            samples = samples.to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            flops = FlopCountAnalysis(model, samples.tensors)
            num = flops.total()
            total = total + num
            print('**********************************')
            print(num)
            print("Avg by " + str(i) + " samples, averaged FLOPs: ", total / i)
            print('**********************************')
            del samples
            del targets
            del flops
            torch.cuda.empty_cache()
            # print(flop_count_str(flops))
            # print(flops.by_module())
            if i >= 100:
                assert False
ZhangGongjie commented 2 years ago

Thank you for your interest! I hope it helps!
:)

cxq1 commented 2 years ago
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
from fvcore.nn import flop_count_str

# Similar codes in main.py are omitted here for simplicity

    with torch.no_grad():
        model.eval()
        i = 0
        total = 0
        for samples, targets in data_loader_val:
            i += 1
            samples = samples.to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            flops = FlopCountAnalysis(model, samples.tensors)
            num = flops.total()
            total = total + num
            print('**********************************')
            print(num)
            print("Avg by " + str(i) + " samples, averaged FLOPs: ", total / i)
            print('**********************************')
            del samples
            del targets
            del flops
            torch.cuda.empty_cache()
            # print(flop_count_str(flops))
            # print(flops.by_module())
            if i >= 100:
                assert False

Thank you for your reply.

ZhangGongjie commented 2 years ago

Issue closed.