zugexiaodui / torch_flops

A library for calculating the FLOPs in the forward() process based on torch.fx
MIT License

results compared with DeepSpeed flops profiler #6

Closed BitCalSaul closed 5 months ago

BitCalSaul commented 5 months ago

Hi, I tried to compare the DeepSpeed flops profiler with torch_flops.

First of all, I found that both tools are very accurate for FLOPs, but not for time.

To reduce randomness, I ran the flops profiler 10 times for the Swin block. The average is 122.767278e9 FLOPs and 4.13 ms. You can see from the screenshot below that the time counter is a little unstable:

[screenshot]

I have reported this issue here: https://github.com/microsoft/DeepSpeed/issues/4976.

As for torch_flops, since it is a little inconvenient to get the total time from it, I only ran it once. The FLOP count is 122,750,500,864 (close to the one from the flops profiler) and the time is 4.8 ms. But if you run torch_flops in different cases, you will find the time counter is even more unstable: in the figures below, the normal time for qkv should be around 0.1 ms, but in another run it is 18 ms.

[screenshots]

Also, other nodes sometimes show unstable timings that are much higher than the expected total time of 4.13 ms.

[screenshot]
zugexiaodui commented 5 months ago

I have also noticed unstable time measurements while testing torch_flops, and I find that running on the CPU is much more unstable than running on the GPU. I have no idea how to solve this problem because it goes deep into the behavior of the hardware.

The total running time will be supported in the next version. Running multiple batches and averaging the running time is a way to get an accurate time measurement; see the sketch below.
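Until the total time is exposed directly, here is a minimal sketch of that averaging approach (plain PyTorch with CUDA events; average_latency_ms is a hypothetical helper, not part of torch_flops):

import torch

@torch.no_grad()
def average_latency_ms(model, x, n_runs=10, n_warmup=2):
    # Average forward latency over several runs, excluding warm-up iterations.
    model.eval()
    for _ in range(n_warmup):
        model(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(n_runs):
        start.record()
        model(x)
        end.record()
        torch.cuda.synchronize()               # wait for all kernels to finish
        times.append(start.elapsed_time(end))  # elapsed time in milliseconds
    return sum(times) / len(times)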

zugexiaodui commented 5 months ago

I find that the time of the first operation, or of the first run of the whole model, is inaccurate. You can simply run the model once with a single line of code, model(x), before you initialize TorchFLOPsByFX, but it will affect the peak GPU memory. I will update the examples in the new version; a sketch of the pattern is shown below.
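A minimal sketch of this warm-up pattern, following the usage shown later in this thread (the small nn.Sequential here is just a stand-in model for illustration):

import torch
import torch.nn as nn
from torch_flops import TorchFLOPsByFX

device = torch.device('cuda:0')
model = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 256)).to(device)
x = torch.randn(1, 64, 256, device=device)

# Warm-up: run the model once so first-run overhead does not pollute the timing below.
with torch.no_grad():
    model(x)

with torch.no_grad():
    flops_counter = TorchFLOPsByFX(model)
    flops_counter.propagate(x)

flops_counter.print_total_flops(show=True)
flops_counter.print_total_time()
flops_counter.print_max_memory()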

BitCalSaul commented 5 months ago

Thanks, I will try it tomorrow and check whether the average time over 50 runs is similar to DeepSpeed's.

zugexiaodui commented 5 months ago

OK. For more accurate time measurement, you can first run model(x) once as introduced in https://github.com/zugexiaodui/torch_flops?tab=readme-ov-file#example-1.

BitCalSaul commented 5 months ago

Thanks for your work. By the way, I'm wondering if it's possible to add percentage values for FLOPs and time, which would make it clearer to users which operations are time- or FLOP-consuming.

zugexiaodui commented 5 months ago

Limited by the width of the screen, it is not practical to show too many columns. However, flops_counter.print_result_table() returns the result table, which contains the FLOPs, time, etc. of each operation. You can use the returned table to compute the percentages for FLOPs and time yourself, as sketched below.
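A minimal sketch of that post-processing, assuming flops_counter is an initialized TorchFLOPsByFX instance that has already propagated an input, and that the returned table can be loaded into a pandas DataFrame with per-operation 'flops' and 'time' columns (the actual column names may differ):

import pandas as pd

result_table = flops_counter.print_result_table()
df = pd.DataFrame(result_table)

# Turn the absolute per-operation numbers into percentages of the totals.
df['flops_%'] = 100.0 * df['flops'] / df['flops'].sum()
df['time_%'] = 100.0 * df['time'] / df['time'].sum()

# Show the ten most time-consuming operations.
print(df.sort_values('time_%', ascending=False).head(10))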

BitCalSaul commented 5 months ago

Thanks, the width is indeed a limitation. I can't think of a better way to show the percentage values either.

BitCalSaul commented 5 months ago

Reference experiment:

To make sure my script works properly, a reference experiment was run on my server. The results for vit_base16 and resnet50 are as follows:

Model       Total FLOPs      Total time (ms)   Max memory (Bytes)
vit_base16  35,164,979,282   23.244            362,289,152
resnet50    8,227,340,288    14.515            139,527,680

It seems that this line of code makes a difference (see the sketch below): import os; os.environ['TIMM_FUSED_ATTN'] = "0".
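A minimal sketch of where the flag has to be set, assuming timm's vit_base_patch16_224 as the test model (the environment variable should be set before timm is imported, since timm appears to read it at import time):

import os
os.environ['TIMM_FUSED_ATTN'] = "0"   # force timm to use the naive (unfused) attention path

import torch
import timm

device = torch.device('cuda:0')
model = timm.create_model('vit_base_patch16_224', pretrained=False).to(device).eval()
x = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    print(model(x).shape)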

My environment:

Name           Version
pytorch        2.1.0
pytorch-cuda   11.8
pytorch-mutex  1.0
timm           0.9.12
torch-flops    0.3.5
torchaudio     2.1.0
torchtriton    2.1.0
torchvision    0.16.0

The results from DeepSpeed

The DeepSpeed build is from a branch in a contributor's repo, https://github.com/KimmiShi/DeepSpeed/tree/flops_profiler_attn, which can count the FLOPs of the @ operation; the released repo cannot. Since the first run of DeepSpeed is not correct, I took the average of 10 runs. For how to accurately get results from multiple runs of the DeepSpeed profiler, see this issue: https://github.com/microsoft/DeepSpeed/issues/4976. A sketch of the averaging loop is given below.
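A minimal sketch of such an averaging loop, assuming the upstream get_model_profile API (the forked branch may differ); latency is not returned by get_model_profile, so it would still have to be read from the printed profile or timed separately:

import torch
import timm
from deepspeed.profiling.flops_profiler import get_model_profile

device = torch.device('cuda:0')
model = timm.create_model('vit_base_patch16_224', pretrained=False).to(device).eval()

flops_runs, macs_runs = [], []
for _ in range(10):
    # as_string=False returns raw numbers instead of formatted strings like "35.15 G"
    flops, macs, params = get_model_profile(
        model=model,
        input_shape=(1, 3, 224, 224),
        print_profile=False,
        as_string=False,
    )
    flops_runs.append(flops)
    macs_runs.append(macs)

print(f"params: {params}")
print(f"avg GFLOPs: {sum(flops_runs) / len(flops_runs) / 1e9:.6f}")
print(f"avg GMACs:  {sum(macs_runs) / len(macs_runs) / 1e9:.6f}")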

vit_base16

GFLOPs      GMACs       Params (k)   Latency (ms)
35.153981 17.563828 86567.656 42.323351
35.153981 17.563828 86567.656 28.429985
35.153981 17.563828 86567.656 27.310371
35.153981 17.563828 86567.656 27.332306
35.153981 17.563828 86567.656 27.568340
35.153981 17.563828 86567.656 28.112888
35.153981 17.563828 86567.656 27.656078
35.153981 17.563828 86567.656 28.559923
35.153981 17.563828 86567.656 28.425694
35.153981 17.563828 86567.656 27.448177
The average values are:

Metric        Value
GFLOPs        35.153981
GMACs         17.563828
Params (k)    86567.656
Latency (ms)  29.316711

resnet50

GFLOPs     GMACs      Params (k)   Latency (ms)
8.211108 4.089184 25557.032 24.617195
8.211108 4.089184 25557.032 22.443771
8.211108 4.089184 25557.032 22.524357
8.211108 4.089184 25557.032 22.032738
8.211108 4.089184 25557.032 22.349596
8.211108 4.089184 25557.032 22.476673
8.211108 4.089184 25557.032 22.083759
8.211108 4.089184 25557.032 22.483349
8.211108 4.089184 25557.032 21.956444
8.211108 4.089184 25557.032 22.108555
The average values are:

Metric        Value
GFLOPs        8.211108
GMACs         4.089184
Params (k)    25557.032
Latency (ms)  22.507644

Comparison

Model       Method       FLOPs           Latency (ms)
vit_base16  torch_flops  35,164,979,282  23.244
resnet50    torch_flops  8,227,340,288   14.515
vit_base16  DeepSpeed    35,153,981,000  29.317
resnet50    DeepSpeed    8,211,108,000   22.509

Some observations

Take vit_base16 as an example: whether the first (warm-up) run is done in no_grad(), and whether the evaluation itself is done in no_grad(), both influence the results. Here is a table:

Case  Description                                          Total FLOPs     Total Time (ms)  Max Memory (Bytes)
1     Warm-up run in no_grad() + evaluation in no_grad()   35,164,979,282  23.092           362,289,152
2     Warm-up run in no_grad() only                        35,164,979,282  27.574           509,074,944
3     Evaluation in no_grad() only                         35,164,979,282  140.948          362,289,152
4     Neither                                              35,164,979,282  145.391          509,074,944

Case 1 seems to be the most correct one. If the warm-up run is skipped, the measured time becomes strangely higher, for a reason I don't know. If the following run (the evaluation) is not done inside torch.no_grad(), the time and the max memory are higher; I guess that is because the model behaves as if it were in training mode. A sketch of the four cases is given below.
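A minimal sketch of how the four cases differ, assuming model and x are the ViT model and input used above (each case should be run in a fresh process, since any earlier forward pass already acts as a warm-up):

import torch
from torch_flops import TorchFLOPsByFX

def profile_case(model, x, warmup: bool, eval_in_no_grad: bool):
    # warmup: run model(x) once inside no_grad() before profiling.
    # eval_in_no_grad: wrap the profiled run itself in no_grad().
    if warmup:
        with torch.no_grad():
            model(x)
    ctx = torch.no_grad() if eval_in_no_grad else torch.enable_grad()
    with ctx:
        counter = TorchFLOPsByFX(model)
        counter.propagate(x)
    counter.print_total_flops(show=True)
    counter.print_total_time()
    counter.print_max_memory()

# Case 1: profile_case(model, x, warmup=True,  eval_in_no_grad=True)
# Case 2: profile_case(model, x, warmup=True,  eval_in_no_grad=False)
# Case 3: profile_case(model, x, warmup=False, eval_in_no_grad=True)
# Case 4: profile_case(model, x, warmup=False, eval_in_no_grad=False)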

As for DeepSpeed, the function get_model_profile adds a model.eval() call, so I guess that is how they put the model into evaluation mode.

BitCalSaul commented 5 months ago

However, the test for the Swin block seems strange. This is the result from DeepSpeed:

[screenshot]

This is the result from torch_flops. The time is much higher than DeepSpeed's:

[screenshot]

And this is the code:

[screenshot]
zugexiaodui commented 5 months ago
1. os.environ['TIMM_FUSED_ATTN'] affects the attention module used in timm. TIMM_FUSED_ATTN=0 means using the naive operations instead of the optimized fused attention. The optimized fused attention is not yet supported in torch_flops, which is why TIMM_FUSED_ATTN is set to 0. Therefore, setting TIMM_FUSED_ATTN to 1 changes the running speed and memory usage.
2. model.eval() is also taken into account in torch_flops. The factor influencing the memory is torch.no_grad: without torch.no_grad, memory for gradients is retained during inference.
3. According to the result of example2.py, the running speed of a Transformer block seems correct. It is strange that the inference time of the Swin block in your code is so high. Would you like to share your full code?
BitCalSaul commented 5 months ago

Sure, here is the code. BTW, the SwinTransformerBlock comes from the official repo, and I didn't change the code.

import math
import torch
import torch.nn as nn
from torch_flops import TorchFLOPsByFX
from swin_attn import SwinTransformerBlock
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

class MySwinTransformerModel(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size, mlp_ratio, depth):
        super(MySwinTransformerModel, self).__init__()
        self.layers = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                num_heads=num_heads, window_size=window_size,
                                shift_size=0 if (i % 2 == 0) else window_size // 2,
                                mlp_ratio=mlp_ratio)
            for i in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

device = torch.device('cuda:0')
dim = 192
batch_size = 1
num_heads = 1
input_resolution = (256,256)
mlp_ratio = 3.
depth = 2
window_size = 16

Swin = MySwinTransformerModel(dim, input_resolution, num_heads, window_size, mlp_ratio, depth).to(device) 
x = torch.randn(batch_size, math.prod(input_resolution), dim).to(device)
print(x.shape)
print("=" * 30, "Torchflops Report", "=" * 30)
# note: First run the model once for accurate time measurement in the following process.
# The input `x` and the model should be placed on GPU for memory measurement.
with torch.no_grad():
    Swin(x)
with torch.no_grad():
    flops_counter = TorchFLOPsByFX(Swin)
    flops_counter.propagate(x)
result_table = flops_counter.print_result_table()
total_flops = flops_counter.print_total_flops(show=True)
total_time = flops_counter.print_total_time()
max_memory = flops_counter.print_max_memory()
zugexiaodui commented 5 months ago

SwinTransformerBlock from https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py raises an error when used with torch_flops. Did you change the code in SwinTransformerBlock?

zugexiaodui commented 5 months ago

Moreover, did you install kernels.window_process.window_process for swin_transformer?

BitCalSaul commented 5 months ago

Hi, I fixed shape[0] to a constant in the function window_reverse (an fx-friendlier alternative is sketched after the file below), and I didn't install the window process kernel.

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

WindowProcess = None
WindowProcessReverse = None

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # B = int(windows.shape[0] / (H * W / window_size / window_size))
    B = 1
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
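For reference, here is a minimal sketch of a window_reverse that avoids both the int() call and the hard-coded B, assuming torch.fx can trace sizes derived from windows.shape[0] (I have not verified this variant with torch_flops):

def window_reverse(windows, window_size, H, W):
    # Keep B symbolic instead of calling int() on a traced shape or hard-coding 1,
    # so the function stays valid for any batch size under torch.fx tracing.
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x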
zugexiaodui commented 5 months ago

The following screenshot shows the result of the SwinTransformer code run on my machine. The running time (10.525 ms) is less than ViT's (14.015 ms), which seems reasonable.

[screenshot]

My env: Ubuntu Server 22.04 LTS, GPU=NVIDIA GeForce RTX 4090, pytorch=2.1.0, pytorch-cuda=11.8, timm=0.9.8, torch-flops=0.3.5.

zugexiaodui commented 5 months ago
for _ in range(10):
    with torch.no_grad():
        flops_counter = TorchFLOPsByFX(Swin)
        flops_counter.propagate(x)
    # result_table = flops_counter.print_result_table()
    # total_flops = flops_counter.print_total_flops(show=True)
    total_time = flops_counter.print_total_time()
    # max_memory = flops_counter.print_max_memory()

[screenshot]

BitCalSaul commented 5 months ago

Thanks for your test. This is really strange. Btw, did you use the model class I sent?

zugexiaodui commented 5 months ago

Yes, I used your code.

flops_test.py:

import math
import torch
import torch.nn as nn
from torch_flops import TorchFLOPsByFX
from swin_attn import SwinTransformerBlock
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

class MySwinTransformerModel(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size, mlp_ratio, depth):
        super(MySwinTransformerModel, self).__init__()
        self.layers = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                num_heads=num_heads, window_size=window_size,
                                shift_size=0 if (i % 2 == 0) else window_size // 2,
                                mlp_ratio=mlp_ratio)
            for i in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

device = torch.device('cuda:0')
dim = 192
batch_size = 1
num_heads = 1
input_resolution = (256,256)
mlp_ratio = 3.
depth = 2
window_size = 16

Swin = MySwinTransformerModel(dim, input_resolution, num_heads, window_size, mlp_ratio, depth).to(device) 
x = torch.randn(batch_size, math.prod(input_resolution), dim).to(device)
print(x.shape)
print("=" * 30, "Torchflops Report", "=" * 30)
# note: First run the model once for accurate time measurement in the following process.
# The input `x` and the model should be placed on GPU for memory measurement.
with torch.no_grad():
    Swin(x)

for _ in range(10):
    with torch.no_grad():
        flops_counter = TorchFLOPsByFX(Swin)
        flops_counter.propagate(x)
    # result_table = flops_counter.print_result_table()
    # total_flops = flops_counter.print_total_flops(show=True)
    total_time = flops_counter.print_total_time()
    # max_memory = flops_counter.print_max_memory()

swin_attn.py:

(identical to the swin_attn.py posted earlier in this thread)
BitCalSaul commented 5 months ago

There must be something strange going on, since I got a much higher result...

[screenshot]

I will check my server...

zugexiaodui commented 5 months ago

Okay~