sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in pytorch framework
MIT License

Request to include FLOP count for Graph Convolutions #71

Open pranavgundewar opened 3 years ago

sovrasov commented 3 years ago

There is no standard module in torch.nn representing graph convolutions, and ptflops can only account for PyTorch's own modules. You can, however, write a custom hook for your GCN implementation and pass it to ptflops.

pranavgundewar commented 2 years ago

@sovrasov Can you share an example of writing a custom hook for a GCN implementation?

Thank you!

sovrasov commented 2 years ago

Hi! Here is a brief example:

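import torch.nn as nn
from ptflops import get_model_complexity_info


class MyModule(nn.Module):
    def forward(self, x):
        return x


def my_module_flops_counter_hook(module, input, output):
    # Placeholder: add the real operation count for this module here.
    module.__flops__ += 0


model = MyModule()
# The dict maps a module class to the hook that counts its operations.
macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
                                         print_per_layer_stat=True,
                                         verbose=True,
                                         custom_modules_hooks={MyModule: my_module_flops_counter_hook})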

Substitute your GCN implementation for MyModule, and replace the `+= 0` placeholder in the hook with the operation count of your layer's forward pass.
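
As a rough sketch of what such a hook might compute, here is one possible count for a dense GCN layer. It assumes the layer is called positionally as `layer(x, adj)`, with `x` of shape `(num_nodes, in_features)`, a dense `(num_nodes, num_nodes)` adjacency matrix, and an output of shape `(num_nodes, out_features)` computed as `adj @ (x @ W)`. The names `dense_gcn_macs` and `gcn_flops_counter_hook` are hypothetical and not part of ptflops, and the count is in multiply-accumulates and ignores bias and activation. The hook only reads `.shape`, so the block itself needs no imports.

```python
def dense_gcn_macs(num_nodes, in_features, out_features):
    """Multiply-accumulates for adj @ (x @ W) with a dense adjacency (assumed layout)."""
    feature_transform = num_nodes * in_features * out_features  # x @ W
    aggregation = num_nodes * num_nodes * out_features          # adj @ (x @ W)
    return feature_transform + aggregation


def gcn_flops_counter_hook(module, input, output):
    # `input` is the tuple of positional arguments passed to forward();
    # this sketch assumes the node features come first.
    x = input[0]
    num_nodes, in_features = x.shape
    out_features = output.shape[-1]
    module.__flops__ += dense_gcn_macs(num_nodes, in_features, out_features)
```

You would then pass it the same way as above, for example `custom_modules_hooks={MyGCNLayer: gcn_flops_counter_hook}`, where `MyGCNLayer` stands for your own layer class. For a sparse adjacency, the aggregation term would more plausibly scale with the number of edges than with `num_nodes * num_nodes`.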