facebookresearch / fvcore

Collection of common code that's shared among different research projects in FAIR computer vision team.
Apache License 2.0

flop count analysis of LSTM layers #98

Open HendrikKlug-synthara opened 2 years ago

HendrikKlug-synthara commented 2 years ago

Hello, it seems that LSTM layers are not yet supported for the fvcore.nn.FlopCountAnalysis method:

import torch
from fvcore.nn import FlopCountAnalysis
from torch import nn

class ToyLSTMModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(10, 20, 1)

    def forward(self, x):
        h0 = torch.randn(1, 3, 20)
        c0 = torch.randn(1, 3, 20)
        output, _ = self.rnn(x, (h0, c0))

        return output

model = ToyLSTMModel()
example_input = torch.randn(5, 3, 10)

print(FlopCountAnalysis(model, example_input).by_module())

gives:

Unsupported operator aten::randn encountered 2 time(s)
Unsupported operator aten::lstm encountered 1 time(s)
Counter({'': 0, 'rnn': 0})

While the same works with an LSTM cell:

import torch
from fvcore.nn import FlopCountAnalysis
from torch import nn

class ToyLSTMModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTMCell(10, 20)

    def forward(self, x):
        hx = torch.randn(3, 20)
        cx = torch.randn(3, 20)
        output = []
        for i in range(x.size()[0]):
            hx, cx = self.rnn(x[i], (hx, cx))
            output.append(hx)
        output = torch.stack(output, dim=0)
        return output

model = ToyLSTMModel()
example_input = torch.randn(5, 3, 10)

print(FlopCountAnalysis(model, example_input).by_module())

output:

Unsupported operator aten::randn encountered 2 time(s)
Unsupported operator aten::add_ encountered 10 time(s)
Unsupported operator aten::unsafe_chunk encountered 5 time(s)
Unsupported operator aten::sigmoid_ encountered 15 time(s)
Unsupported operator aten::tanh_ encountered 5 time(s)
Unsupported operator aten::mul encountered 15 time(s)
Unsupported operator aten::tanh encountered 5 time(s)
Counter({'': 36000, 'rnn': 36000})

Is there any particular reason for that? The number of FLOPs of the LSTM layer should be the same as that of one LSTM cell times the number of time steps.
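For reference, the 36000 figure reported for the LSTMCell version can be reproduced by hand. The sketch below assumes that one multiply-accumulate is counted as one flop and that only the two gate projections (input-to-hidden and hidden-to-hidden) are counted, which is consistent with the elementwise ops being listed as unsupported in the output above. The helper name `lstm_cell_macs` is made up for illustration and is not part of fvcore.

```python
def lstm_cell_macs(batch, input_size, hidden_size):
    """Multiply-accumulates for one LSTMCell step, matmuls only.

    Assumption: biases and elementwise ops (sigmoid, tanh, mul, add)
    are ignored, matching the "Unsupported operator" lines above.
    """
    gates = 4 * hidden_size                    # input, forget, cell, output gates
    input_proj = batch * input_size * gates    # x @ W_ih^T
    hidden_proj = batch * hidden_size * gates  # h @ W_hh^T
    return input_proj + hidden_proj


# Shapes taken from the example: input (5, 3, 10), nn.LSTMCell(10, 20)
seq_len, batch, input_size, hidden_size = 5, 3, 10, 20

per_step = lstm_cell_macs(batch, input_size, hidden_size)
total = seq_len * per_step

print(per_step)  # 7200
print(total)     # 36000
```

Under these assumptions the per-step cost times the 5 time steps matches the `Counter` value printed for the LSTMCell model.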

ppwwyyxx commented 2 years ago

Is there any particular reason for that?

Because LSTMCell uses matmul / linear under the hood while LSTM is a separate op. The flop formula for LSTM ops is not implemented. That's why there is a warning Unsupported operator aten::lstm encountered 1 time(s).
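Until a formula for aten::lstm exists, a manual estimate is one option. The following is only a sketch of what such a formula might look like, not fvcore's implementation: it counts one multiply-accumulate per flop, covers only the gate matmuls, and ignores `proj_size`, biases, and elementwise ops. The function name `lstm_layer_macs` is hypothetical.

```python
def lstm_layer_macs(seq_len, batch, input_size, hidden_size,
                    num_layers=1, bidirectional=False):
    """Estimated multiply-accumulates for an nn.LSTM forward pass.

    Assumptions (not taken from fvcore): matmuls only, no proj_size,
    and every stacked layer after the first consumes the concatenated
    outputs of all directions of the previous layer.
    """
    num_directions = 2 if bidirectional else 1
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        # Per time step and direction: x @ W_ih^T plus h @ W_hh^T
        per_step = batch * 4 * hidden_size * (layer_input + hidden_size)
        total += seq_len * per_step * num_directions
        layer_input = hidden_size * num_directions
    return total


# nn.LSTM(10, 20, 1) on an input of shape (5, 3, 10), as in the report
print(lstm_layer_macs(5, 3, 10, 20))  # 36000
```

For the single-layer, unidirectional case this gives the same number as the LSTMCell loop. To have the analysis itself report it, a handler for "aten::lstm" would need to be registered (I believe via `set_op_handle` on the analysis object), which additionally requires reading the shapes from the traced graph inputs.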

Drglitch791 commented 2 months ago

Is there a way to run these functions silently without generating warnings?
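One generic approach is to raise the level of the logger that emits these messages while the analysis runs. This is a sketch using only the standard library. It assumes the warnings go through Python's `logging` module under the logger name "fvcore.nn.jit_analysis", which is a guess based on the module layout and should be checked against the installed version. `quiet_logger` is a made-up helper name.

```python
import logging
from contextlib import contextmanager


@contextmanager
def quiet_logger(name, level=logging.ERROR):
    """Temporarily raise a logger's level, restoring it afterwards."""
    logger = logging.getLogger(name)
    previous = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        logger.setLevel(previous)


# Assumed logger name; adjust if the warnings come from elsewhere.
with quiet_logger("fvcore.nn.jit_analysis"):
    # Run the analysis here, for example:
    # flops = FlopCountAnalysis(model, example_input).by_module()
    pass
```

The analysis object also appears to offer its own switches for this, such as `unsupported_ops_warnings(False)`, which may be the simpler route if they are available in your version.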