facebookresearch / fvcore

Collection of common code that's shared among different research projects in the FAIR computer vision team.
Apache License 2.0
2k stars, 226 forks

Counting FLOPs for a custom op with set_op_handle: a toy example that doesn't work #147

Open guynich opened 5 months ago

guynich commented 5 months ago

I'm extending the example given for fvcore.nn.FlopCountAnalysis so that it also counts the FLOPs of a custom op inside my model class.

import torch

from collections import Counter

from fvcore.nn import FlopCountAnalysis
from torch import nn

class TestModel(nn.Module):
    """Toy model."""
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
        self.fc = nn.Linear(in_features=1000, out_features=10)

    def forward(self, x):
        _ = self.custom_op_flop_counter(inputs=x, outputs=None)
        return self.fc(self.act(self.conv(x)).flatten(1))

    @staticmethod
    # Has no access to anything else in the class.
    def custom_op_flop_counter(inputs, outputs) -> Counter:
        """Returns counter value to include in flops."""
        # The function should return a counter object with per-operator statistics.
        return Counter({'custom_op': 500})

model = TestModel()
inputs = (torch.randn((1, 3, 10, 10)),)

flops = FlopCountAnalysis(
    model,
    inputs).set_op_handle(
        "custom_op", model.custom_op_flop_counter)

print(flops.by_module_and_operator())

Neither "custom_op" nor its returned value of 500 appears in the printed output, although the expected counts for the conv and linear operators do: {'': Counter({'linear': 10000, 'conv': 3000}), 'act': Counter(), 'conv': Counter({'conv': 3000}), 'fc': Counter({'linear': 10000})}. What am I doing wrong that prevents my custom op's count from being included?