facebookresearch / fvcore

Collection of common code that's shared among different research projects in FAIR computer vision team.
Apache License 2.0

Matmul Flops #108

Open breuera opened 2 years ago

breuera commented 2 years ago

The matmul flop counts seem to be off by 2x.

I tested the code on a simple MLP, which reads as follows:

import torch.nn

## @package eml.mlp.Module
#  Simple MultiLayer Perceptron (MLP) with fixed dimensions.
#
#  The MLP assumes a 28^2 input image and 10 output classes.
#  These are the dimensions of the Fashion MNIST dataset.
class Model( torch.nn.Module ):
  ## Initializes the class.
  #  @param self object pointer.
  def __init__( self ):
    super( Model, self ).__init__()
    ## flattens the input
    self.m_flatten = torch.nn.Flatten()
    ## layers of the MLP: 3x(linear + relu)
    self.m_layers = torch.nn.Sequential( torch.nn.Linear( 28*28, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 512 ),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear( 512, 10 ) )

  ## Forward pass with the given input.
  #  @param self object pointer.
  #  @param i_input input for the forward pass.
  #  @return output of the MLP.
  def forward( self,
               i_input ):
    l_flatten = self.m_flatten( i_input )
    l_result = self.m_layers( l_flatten )
    return l_result

I embedded this in some surrounding code; the crucial piece is here:

l_model = eml.mlp.model.Model()
[...]
print( l_model )

#
# flop count code
# https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
#
import fvcore.nn

l_x, l_y = next(iter(l_data_loader_train))

print( l_x.size() )

l_flops = fvcore.nn.FlopCountAnalysis( l_model,
                                       l_x )

print( l_flops.by_module_and_operator() )

print( fvcore.nn.flop_count_table( l_flops ) )

This returns:

Model(
  (m_flatten): Flatten(start_dim=1, end_dim=-1)
  (m_layers): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
torch.Size([64, 1, 28, 28])
{'': Counter({'addmm': 42795008}), 'm_flatten': Counter(), 'm_layers': Counter({'addmm': 42795008}), 'm_layers.0': Counter({'addmm': 25690112}), 'm_layers.1': Counter(), 'm_layers.2': Counter({'addmm': 16777216}), 'm_layers.3': Counter(), 'm_layers.4': Counter({'addmm': 327680})}
| module     | #parameters or shape   | #flops   |
|:-----------|:-----------------------|:---------|
| m_layers   | 0.67M                  | 42.795M  |
|  0         |  0.402M                |  25.69M  |
|   0.weight |   (512, 784)           |          |
|   0.bias   |   (512,)               |          |
|  2         |  0.263M                |  16.777M |
|   2.weight |   (512, 512)           |          |
|   2.bias   |   (512,)               |          |
|  4         |  5.13K                 |  0.328M  |
|   4.weight |   (10, 512)            |          |
|   4.bias   |   (10,)                |          |

Let's take the first linear layer as an example: matrix A in https://pytorch.org/docs/stable/generated/torch.nn.Linear.html has shape (512, 784), and matrix x (since the example is batched) has shape (64, 784). Computing the result C = xA^T requires 2*64*512*784 - 64*512 floating point operations. However, the example also uses a bias, which adds another 64*512 additions, giving 2*64*512*784 = 51,380,224 flops in total; the tool reports 25,690,112 for the first layer. By the way, I am not sure why the bias doesn't show up separately.
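To make the factor of two explicit, here is a small sanity check of that arithmetic (shapes taken from the model above; the reported value is the one from the table):

# Sanity check of the first-layer arithmetic (shapes from the model above).
batch, in_features, out_features = 64, 784, 512

# Conventional count: 2 flops per multiply-add, plus the bias additions.
matmul_flops = 2 * batch * in_features * out_features - batch * out_features
bias_flops = batch * out_features
conventional = matmul_flops + bias_flops           # 2*64*512*784 = 51,380,224

# What fvcore reports for m_layers.0.
reported = 25_690_112                              # = 64*512*784, i.e. MACs only

print(conventional, reported, conventional / reported)   # 51380224 25690112 2.0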

I believe that the code below is off, since the number of operations for C += A*B (in BLAS terms) is 2*M*N*K, not M*N*K:

https://github.com/facebookresearch/fvcore/blob/e4f0b3d1fa9ed610a5568932ab7aaf5a37cd75ca/fvcore/nn/jit_handles.py#L225
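A possible workaround is to override the handle for aten::addmm so that it counts 2 flops per multiply-add plus the bias additions. The sketch below is only that, a sketch: it assumes fvcore's set_op_handle API, the get_shape helper in fvcore.nn.jit_handles, and the addmm argument order (bias, mat1, mat2) that nn.Linear produces.

from typing import Any, List

from fvcore.nn import FlopCountAnalysis
from fvcore.nn.jit_handles import get_shape


def addmm_2mnk_flops(inputs: List[Any], outputs: List[Any]) -> int:
    # aten::addmm(bias, mat1, mat2, ...): for nn.Linear, mat1 has shape
    # (batch, in_features) and mat2 has shape (in_features, out_features).
    # Count 2 flops per multiply-add plus the bias additions.
    batch, in_dim = get_shape(inputs[1])
    out_dim = get_shape(inputs[2])[1]
    return 2 * batch * in_dim * out_dim + batch * out_dim


l_flops = FlopCountAnalysis(l_model, l_x)
l_flops.set_op_handle("aten::addmm", addmm_2mnk_flops)
print(l_flops.total())   # addmm now counted with the 2*M*N*K + bias convention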

breuera commented 2 years ago

It appears that https://github.com/facebookresearch/fvcore/issues/69#issue-895213776 raises the same concern.

> We count one fused multiply-add as one flop.

I'd consider this to be an unconventional definition. Even knowing this, I don't understand how the ops of the bias fit in there.

ppwwyyxx commented 2 years ago

Different groups adopt different conventions, unfortunately. We implemented the convention used in computer vision, which is to count MACs and to ignore the flops of the bias.

#77 would improve this, but unfortunately no one is working on it.
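For reference, under the MAC convention the per-layer numbers reported above can be reproduced directly (batch size 64, layer shapes from the model):

# Under the MAC convention, a Linear layer costs batch * in_features * out_features.
batch = 64
layers = [(784, 512), (512, 512), (512, 10)]   # (in_features, out_features) per Linear

macs = [batch * i * o for i, o in layers]
print(macs)        # [25690112, 16777216, 327680] -- matches the per-layer table
print(sum(macs))   # 42795008 -- matches the model total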