zugexiaodui / torch_flops

A library for calculating the FLOPs of the forward() pass, based on torch.fx
MIT License

The error with Window_reverse in the Swin #4

Closed: BitCalSaul closed this issue 10 months ago

BitCalSaul commented 10 months ago

Hi, I tried to measure the FLOPs of the Swin Transformer using torch_flops, but got the following error:

Exception has occurred: TypeError
int() argument must be a string, a bytes-like object or a real number, not 'Proxy'
  File "/home/guest/Compressor/Attn/attn_old.py", line 63, in window_reverse
    B = int(windows.shape[0] / (H * W / window_size / window_size))
  File "/home/guest/Compressor/Attn/attn_old.py", line 278, in forward
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
  File "/home/guest/Compressor/Attn/flopsTest.py", line 45, in forward
    x = layer(x)
  File "/home/guest/Compressor/Attn/flopsTest.py", line 100, in <module>
    flops_counter = TorchFLOPsByFX(SwinModel)
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'Proxy'

I stepped through it and found that the variable "windows" is a torch.fx Proxy object, as the screenshot below shows.

[Screenshot: debugger view showing "windows" as a Proxy object]

The function is from the official repo of Swin https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L45.
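The failure can be reproduced without torch_flops by calling torch.fx.symbolic_trace directly on a module that only runs window_reverse. The function body below is reproduced from the linked Swin file for illustration (compare it against that file); the wrapper class ReverseOnly and the sizes (8x8 map, 4x4 windows, 3 channels) are hypothetical. In eager mode windows.shape[0] is a real int, but during tracing it is a Proxy, and int() on a Proxy raises the same TypeError as in the traceback above.

```python
import torch
import torch.fx as fx
import torch.nn as nn


def window_reverse(windows, window_size, H, W):
    # Reproduced from the linked Swin file for illustration.
    # windows: (num_windows*B, window_size, window_size, C) -> returns (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class ReverseOnly(nn.Module):
    # Hypothetical minimal module that only calls window_reverse.
    def __init__(self, H=8, W=8, window_size=4):
        super().__init__()
        self.H, self.W, self.window_size = H, W, window_size

    def forward(self, windows):
        return window_reverse(windows, self.window_size, self.H, self.W)


# 8 windows of 4x4 with 3 channels -> B = 8 / ((8 * 8) / 4 / 4) = 2
windows = torch.randn(8, 4, 4, 3)

# Eager mode works: windows.shape[0] is a plain Python int here.
eager_out = ReverseOnly()(windows)
print(eager_out.shape)  # torch.Size([2, 8, 8, 3])

# Symbolic tracing fails: windows is a Proxy, so int(...) has no number to convert.
trace_error = None
try:
    fx.symbolic_trace(ReverseOnly())
except TypeError as err:
    trace_error = err
    print("TypeError:", err)
```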

Here is minimal code that reproduces the error:

import math
import torch
import torch.nn as nn
from torch_flops import TorchFLOPsByFX
from swin_transformer import SwinTransformerBlock

class MySwinTransformerModel(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size, mlp_ratio, depth):
        super(MySwinTransformerModel, self).__init__()
        self.layers = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio)
            for i in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dim = 192
batch_size = 1
num_heads = 1
input_resolution = (256,256)
mlp_ratio = 3.
depth = 2
window_size = 16

SwinModel = MySwinTransformerModel(dim, input_resolution, num_heads, window_size, mlp_ratio, depth).to(device) 

x = torch.randn(batch_size, math.prod(input_resolution), dim).to(device)
print("=" * 30, "Torchflops Report", "=" * 30)
flops_counter = TorchFLOPsByFX(SwinModel)  # the TypeError above is raised here, while the model is traced
flops_counter.propagate(x)
flops_counter.print_result_table()
total_flops = flops_counter.print_total_flops(show=True)

Thanks! If you need more information, please let me know.

zugexiaodui commented 10 months ago

This problem is caused by windows.shape[0]. During tracing, torch.fx does not run the model on real data: every tensor in forward() is replaced by a symbolic Proxy, a stream "flowing" through the graph (yes, like "tensorflow") rather than a concrete value. Accessing windows.shape is recorded in the graph, but the result is itself a Proxy, so int() has no actual number to convert and raises the TypeError. This is a limitation of torch.fx. The solution is to replace the tensor.shape values used in forward() with pre-defined, deterministic values, albeit at the cost of some manual effort.
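A minimal sketch of the "pre-defined deterministic value" fix, assuming the batch size is known when the module is built. The names window_reverse_fixed_b and ReverseBlock, and the sizes used, are hypothetical; the only change from the official function is that B is passed in as a plain Python int instead of being computed from windows.shape[0], so tracing never calls int() on a Proxy.

```python
import torch
import torch.fx as fx
import torch.nn as nn


def window_reverse_fixed_b(windows, window_size, H, W, B):
    # Hypothetical variant: B is a plain int supplied by the caller,
    # so no int() is applied to a traced value.
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class ReverseBlock(nn.Module):
    # Hypothetical wrapper: batch size, H, W and window size are fixed at construction time.
    def __init__(self, batch_size, H, W, window_size):
        super().__init__()
        self.batch_size = batch_size
        self.H, self.W, self.window_size = H, W, window_size

    def forward(self, windows):
        return window_reverse_fixed_b(windows, self.window_size, self.H, self.W, self.batch_size)


block = ReverseBlock(batch_size=2, H=8, W=8, window_size=4)
gm = fx.symbolic_trace(block)  # traces without the Proxy TypeError

# (num_windows * B, window_size, window_size, C) = (2 * 2 * 2, 4, 4, 3)
windows = torch.randn(8, 4, 4, 3)
out = gm(windows)
print(out.shape)  # torch.Size([2, 8, 8, 3])
```

The trade-off is that the batch size is baked into the traced graph, so the traced module only works for inputs with that batch size.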

zugexiaodui commented 10 months ago

According to my experiments, tracing succeeds when tensor.shape is passed directly as an argument to reshape(). In other cases (e.g. int(x.shape[0])), where the shape has to become a concrete Python number, tensor.shape cannot be used.
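In the same spirit, window_reverse can be sketched so that shape values only ever flow into view() and are never converted with int(). This is an alternative, not the fix used in this thread: the batch dimension is left as -1 for view() to infer, and the channel count is taken from windows.shape[-1]. The names window_reverse_no_int and ReverseNoInt are hypothetical, and this has only been reasoned about against plain torch.fx tracing, not against torch_flops.

```python
import torch
import torch.fx as fx
import torch.nn as nn


def window_reverse_no_int(windows, window_size, H, W):
    # Hypothetical variant: the batch dim is -1 (inferred by view) and C comes from
    # windows.shape[-1], which is passed straight into view(); no int() on a Proxy.
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x


class ReverseNoInt(nn.Module):
    def __init__(self, H=8, W=8, window_size=4):
        super().__init__()
        self.H, self.W, self.window_size = H, W, window_size

    def forward(self, windows):
        return window_reverse_no_int(windows, self.window_size, self.H, self.W)


gm = fx.symbolic_trace(ReverseNoInt())  # tracing succeeds
eager = ReverseNoInt()

# The shape lookup is recorded as graph nodes, so the traced module still
# adapts to the batch size: 4 windows -> B = 1, 8 windows -> B = 2.
shapes = []
for num_windows in (4, 8):
    windows = torch.randn(num_windows, 4, 4, 3)
    traced_out = gm(windows)
    assert torch.equal(traced_out, eager(windows))
    shapes.append(tuple(traced_out.shape))
print(shapes)  # [(1, 8, 8, 3), (2, 8, 8, 3)]
```

Unlike a hard-coded B, this keeps the traced graph usable for any batch size, at the cost of relying on view() to infer the batch dimension.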

BitCalSaul commented 10 months ago

Thanks! I fixed B to a number instead of computing it from shape[0], and now the profiler works well. Next time I will try to rewrite the model so that its flow is static.