LeapLabTHU / MLLA

Official repository of MLLA (NeurIPS 2024)

Potential Inaccuracy in FLOPs Computation #18

Closed: wullia closed this issue 3 months ago

wullia commented 3 months ago

I have been evaluating the MLLA-Tiny model and noticed a potential discrepancy in the FLOPs computation reported in the MLLA paper. Using fvcore, I computed approximately 4.16G FLOPs for MLLA-Tiny, which matches the figure reported in the paper. On closer examination, however, the FLOPs of the linear attention components do not appear to be fully accounted for: the products $K^\top V$ and $Q(K^\top V)$ seem to be missing from the calculation. Below are the logs from the FLOPs calculation for the stem and the first stage, where the discrepancy appears:

MLLA(
  #params: 24.39M, #flops: 4.16G
  (patch_embed): Stem(
    #params: 0.11M, #flops: 0.54G
    (conv1): ConvLayer(
      #params: 0.93K, #flops: 12.85M
      (conv): Conv2d(
        3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        #params: 0.86K, #flops: 10.84M
      )
      (norm): BatchNorm2d(
        32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        #params: 64, #flops: 2.01M
      )
      (act): ReLU()
    )
    (conv2): Sequential(
      #params: 18.56K, #flops: 0.24G
      (0): ConvLayer(
        #params: 9.28K, #flops: 0.12G
        (conv): Conv2d(
          32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          #params: 9.22K, #flops: 0.12G
        )
        (norm): BatchNorm2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          #params: 64, #flops: 2.01M
        )
        (act): ReLU()
      )
      (1): ConvLayer(
        #params: 9.28K, #flops: 0.12G
        (conv): Conv2d(
          32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          #params: 9.22K, #flops: 0.12G
        )
        (norm): BatchNorm2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          #params: 64, #flops: 2.01M
        )
      )
    )
    (conv3): Sequential(
      #params: 90.75K, #flops: 0.29G
      (0): ConvLayer(
        #params: 74.24K, #flops: 0.24G
        (conv): Conv2d(
          32, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
          #params: 73.73K, #flops: 0.23G
        )
        (norm): BatchNorm2d(
          256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          #params: 0.51K, #flops: 4.01M
        )
        (act): ReLU()
      )
      (1): ConvLayer(
        #params: 16.51K, #flops: 52.38M
        (conv): Conv2d(
          256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
          #params: 16.38K, #flops: 51.38M
        )
        (norm): BatchNorm2d(
          64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          #params: 0.13K, #flops: 1M
        )
      )
    )
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    #params: 23.77M, #flops: 3.62G
    (0): BasicLayer(
      dim=64, input_resolution=(56, 56), depth=2
      #params: 0.22M, #flops: 0.54G
      (blocks): ModuleList(
        #params: 0.11M, #flops: 0.38G
        (0): MLLABlock(
          dim=64, input_resolution=(56, 56), num_heads=2, mlp_ratio=4.0
          #params: 56.7K, #flops: 0.19G
          (cpe1): Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
            #params: 0.64K, #flops: 1.81M
          )
          (norm1): LayerNorm(
            (64,), eps=1e-05, elementwise_affine=True
            #params: 0.13K, #flops: 1M
          )
          (in_proj): Linear(
            in_features=64, out_features=64, bias=True
            #params: 4.16K, #flops: 12.85M
          )
          (act_proj): Linear(
            in_features=64, out_features=64, bias=True
            #params: 4.16K, #flops: 12.85M
          )
          (dwc): Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
            #params: 0.64K, #flops: 1.81M
          )
          (act): SiLU()
          (attn): LinearAttention(
            dim=64, num_heads=2
            #params: 8.96K, #flops: 40.54M
            (qk): Linear(
              in_features=64, out_features=128, bias=True
              #params: 8.32K, #flops: 25.69M
            )
            (elu): ELU(alpha=1.0)
            (lepe): Conv2d(
              64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
              #params: 0.64K, #flops: 1.81M
            )
            (rope): RoPE()
          )
          (out_proj): Linear(
            in_features=64, out_features=64, bias=True
            #params: 4.16K, #flops: 12.85M
          )
          (drop_path): Identity(#params: 0, #flops: N/A)
          (cpe2): Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
            #params: 0.64K, #flops: 1.81M
          )
          (norm2): LayerNorm(
            (64,), eps=1e-05, elementwise_affine=True
            #params: 0.13K, #flops: 1M
          )
          (mlp): Mlp(
            #params: 33.09K, #flops: 0.1G
            (fc1): Linear(
              in_features=64, out_features=256, bias=True
              #params: 16.64K, #flops: 51.38M
            )
            (act): GELU(approximate='none')
            (fc2): Linear(
              in_features=256, out_features=64, bias=True
              #params: 16.45K, #flops: 51.38M
            )
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): MLLABlock(
          dim=64, input_resolution=(56, 56), num_heads=2, mlp_ratio=4.0
          #params: 56.7K, #flops: 0.19G
          (cpe1): Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
            #params: 0.64K, #flops: 1.81M
          )
          (norm1): LayerNorm(
            (64,), eps=1e-05, elementwise_affine=True
            #params: 0.13K, #flops: 1M
          )
          (in_proj): Linear(
            in_features=64, out_features=64, bias=True
            #params: 4.16K, #flops: 12.85M
          )
          (act_proj): Linear(
            in_features=64, out_features=64, bias=True
            #params: 4.16K, #flops: 12.85M
          )
          (dwc): Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
            #params: 0.64K, #flops: 1.81M
          )
          (act): SiLU()
          (attn): LinearAttention(
            dim=64, num_heads=2
            #params: 8.96K, #flops: 40.54M
            (qk): Linear(
              in_features=64, out_features=128, bias=True
              #params: 8.32K, #flops: 25.69M
            )
            (elu): ELU(alpha=1.0)
            (lepe): Conv2d(
              64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
              #params: 0.64K, #flops: 1.81M
            )
            (rope): RoPE()
          )
          (out_proj): Linear(
            in_features=64, out_features=64, bias=True
            #params: 4.16K, #flops: 12.85M
          )
          (drop_path): DropPath()
          (cpe2): Conv2d(
            64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
            #params: 0.64K, #flops: 1.81M
          )
          (norm2): LayerNorm(
            (64,), eps=1e-05, elementwise_affine=True
            #params: 0.13K, #flops: 1M
          )
          (mlp): Mlp(
            #params: 33.09K, #flops: 0.1G
            (fc1): Linear(
              in_features=64, out_features=256, bias=True
              #params: 16.64K, #flops: 51.38M
            )
            (act): GELU(approximate='none')
            (fc2): Linear(
              in_features=256, out_features=64, bias=True
              #params: 16.45K, #flops: 51.38M
            )
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
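For scale, the two products can be counted by hand. Assuming fvcore's convention of counting one multiply-add as one FLOP, with $N = HW$ tokens, $h$ heads, channel width $C$, and per-head dimension $d = C/h$:

$$
\mathrm{FLOPs}(K^\top V) = h \cdot d \cdot N \cdot d = \frac{N C^2}{h}, \qquad
\mathrm{FLOPs}\big(Q(K^\top V)\big) = h \cdot N \cdot d \cdot d = \frac{N C^2}{h}.
$$

For the first stage of MLLA-Tiny ($N = 56 \times 56 = 3136$, $C = 64$, $h = 2$), each product costs $3136 \cdot 64^2 / 2 = 6{,}422{,}528 \approx 6.42\text{M}$ FLOPs, or about 12.85M per block.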

Could you please help address this concern? If my observations are correct, I suggest updating the FLOPs calculations in your paper to ensure a fair comparison with other models.

Thank you for looking into this matter.

tian-qing001 commented 3 months ago

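Hi @wullia. The FLOPs for $K^\top V$ and $Q(K^\top V)$ are counted by fvcore but not listed as separate entries, because these products are functional operations inside the module's forward pass rather than child modules. For example, in your log, the FLOPs of a linear attention block (40.54M) are larger than the sum of the FLOPs of all the listed components (25.69M + 1.81M = 27.50M):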

          (attn): LinearAttention(
            dim=64, num_heads=2
            #params: 8.96K, #flops: 40.54M
            (qk): Linear(
              in_features=64, out_features=128, bias=True
              #params: 8.32K, #flops: 25.69M
            )
            (elu): ELU(alpha=1.0)
            (lepe): Conv2d(
              64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
              #params: 0.64K, #flops: 1.81M
            )
            (rope): RoPE()
          )

I believe the FLOPs reported in our paper are accurate.
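The gap can be reconciled with a short calculation. The sketch below recomputes the `LinearAttention` entry for the first stage from the shapes in the log (`dim=64`, `num_heads=2`, `input_resolution=(56, 56)`). It rests on assumptions that are not taken from the MLLA code: fvcore counts one multiply-add as one FLOP, elementwise ops such as ELU and RoPE are not counted, and the module also computes a normalizer of the form $Q\,\bar{K}^\top$ (one $d$-dimensional dot product per token and head), as in typical linear-attention implementations. The helper names are hypothetical.

```python
"""Recompute the LinearAttention FLOPs of one stage-1 MLLABlock.

Assumptions (hypothetical, not from the MLLA code): one multiply-add = one
FLOP (fvcore convention), biases and elementwise ops are not counted, and the
unlisted matmuls are K^T V, Q (K^T V) and a normalizer Q @ mean(K)^T.
"""

N_TOKENS = 56 * 56   # input_resolution=(56, 56)
DIM = 64             # dim=64
NUM_HEADS = 2        # num_heads=2


def linear_macs(n_tokens: int, in_features: int, out_features: int) -> int:
    """MACs of an nn.Linear applied to every token (bias not counted)."""
    return n_tokens * in_features * out_features


def depthwise_conv_macs(n_tokens: int, channels: int, kernel: int = 3) -> int:
    """MACs of a stride-1, same-padding depthwise conv over the token grid."""
    return n_tokens * channels * kernel * kernel


def attention_matmul_macs(n_tokens: int, dim: int, num_heads: int) -> dict:
    """MACs of the functional matmuls that have no child module in the log."""
    d = dim // num_heads
    return {
        "K^T V": num_heads * d * n_tokens * d,       # (d x N) @ (N x d) per head
        "Q (K^T V)": num_heads * n_tokens * d * d,   # (N x d) @ (d x d) per head
        "normalizer": num_heads * n_tokens * d,      # (N x d) @ (d x 1) per head (assumed)
    }


qk = linear_macs(N_TOKENS, DIM, 2 * DIM)   # qk: Linear(64 -> 128)
lepe = depthwise_conv_macs(N_TOKENS, DIM)  # lepe: 3x3 depthwise Conv2d
hidden = attention_matmul_macs(N_TOKENS, DIM, NUM_HEADS)
total = qk + lepe + sum(hidden.values())

print(f"qk: {qk / 1e6:.2f}M")
print(f"lepe: {lepe / 1e6:.2f}M")
for name, macs in hidden.items():
    print(f"{name}: {macs / 1e6:.2f}M")
print(f"total: {total / 1e6:.2f}M (logged: 40.54M)")
```

Under these assumptions the total comes to 40.54M, matching the logged value; without the normalizer term it would fall about 0.2M short, so that term is the least certain part of the sketch.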

wullia commented 3 months ago

Got it, thx for your reply!