Closed BitCalSaul closed 5 months ago
I have noticed unstable time measurements while testing torch_flops, and timing on the CPU is much less stable than on the GPU. I have no idea how to solve this problem, because it depends on low-level hardware behavior.
The total running time will be supported in the next version. Running multiple batches and averaging the running time gives a more accurate measurement.
I find that the time of the first operation, or of the first run of the whole model, is inaccurate. You can simply run the model once with a single line of code, `model(x)`, before you initialize `TorchFLOPsByFX`, but this will affect the peak GPU memory. I will update the examples in the new version.
Thanks, I will try it tomorrow. I will check whether the average time over 50 runs is similar to DeepSpeed's.
OK. For more accurate time measurement, you can first run `model(x)` once, as introduced in https://github.com/zugexiaodui/torch_flops?tab=readme-ov-file#example-1.
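The warm-up-then-average advice above can be sketched with the standard library alone. This is a minimal illustration, not part of torch_flops: `measure_latency_ms` and the stand-in workload are hypothetical names. On a GPU, kernels launch asynchronously, so the timed callable would also need to call `torch.cuda.synchronize()` before returning for wall-clock timing to be meaningful.

```python
import statistics
import time


def measure_latency_ms(fn, warmup=1, runs=10):
    """Call fn() `warmup` times untimed, then time `runs` further calls (in ms)."""
    for _ in range(warmup):
        fn()  # warm-up calls are executed but not recorded
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return {
        "mean": statistics.mean(samples),
        "median": statistics.median(samples),  # less sensitive to single outliers
        "stdev": statistics.stdev(samples) if runs > 1 else 0.0,
        "samples": samples,
    }


# Stand-in workload; in practice this could be e.g. `lambda: model(x)` (hypothetical names).
stats = measure_latency_ms(lambda: sum(i * i for i in range(10_000)), warmup=2, runs=5)
print(f"mean={stats['mean']:.3f} ms  median={stats['median']:.3f} ms")
```

Reporting the median next to the mean is a design choice: a single slow run moves the mean noticeably but barely affects the median.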
Thanks for your work. By the way, I'm wondering if it's possible to add percentage values for FLOPs and time, which would make it clearer to users which operations are the most time- or FLOP-consuming.
Limited by the width of the screen, it is not suitable to show too many columns. However, `flops_counter.print_result_table()` returns the result table, which contains the FLOPs, time, etc. of each operation. You can use the result table to calculate the percentages for FLOPs and time.
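As a rough sketch of that percentage calculation: the exact layout of the table returned by `print_result_table()` is not shown in this thread, so the code below assumes (hypothetically) that it has already been reduced to `(name, flops, time_ms)` rows. The row values are made up for illustration.

```python
def add_percentages(rows):
    """rows: iterable of (name, flops, time_ms); returns rows with two share columns."""
    rows = list(rows)
    total_flops = sum(r[1] for r in rows)
    total_time = sum(r[2] for r in rows)
    out = []
    for name, flops, time_ms in rows:
        flops_pct = 100.0 * flops / total_flops if total_flops else 0.0
        time_pct = 100.0 * time_ms / total_time if total_time else 0.0
        out.append((name, flops, time_ms, flops_pct, time_pct))
    return out


# Made-up numbers purely for illustration, not measured values.
example_rows = [("qkv", 600, 0.10), ("attn_matmul", 300, 0.25), ("softmax", 100, 0.15)]
with_pct = add_percentages(example_rows)
for name, flops, time_ms, flops_pct, time_pct in with_pct:
    print(f"{name:12s} {flops_pct:5.1f}% FLOPs  {time_pct:5.1f}% time")
```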
Thanks, the width is indeed a limitation. I can't imagine a better way to show percentage values.
To make sure my script works properly, I ran a reference experiment on my server. The results for vit_base16 and resnet50 are as follows:
total_flops = 35,164,979,282 total_time = 23.244 ms max_memory = 362,289,152 Bytes
total_flops = 8,227,340,288 total_time = 14.515 ms max_memory = 139,527,680 Bytes
It seems that these lines of code have an influence:
import os
os.environ['TIMM_FUSED_ATTN'] = "0"
Name | Version |
---|---|
pytorch | 2.1.0 |
pytorch-cuda | 11.8 |
pytorch-mutex | 1.0 |
torch-flops | 0.3.5 |
torchaudio | 2.1.0 |
torchtriton | 2.1.0 |
torchvision | 0.16.0 |
timm | 0.9.12 |
The DeepSpeed version is from a branch in a contributor's repo, https://github.com/KimmiShi/DeepSpeed/tree/flops_profiler_attn, which can count the FLOPs of the @ operation; the released repo cannot. Since the first run of DeepSpeed is not accurate, I took the average of 10 runs. For how to accurately get the results from multiple runs of the DeepSpeed profiler, see this issue: https://github.com/microsoft/DeepSpeed/issues/4976.
GFLOPs | GMACs | Params | Latency (ms) |
---|---|---|---|
35.153981 | 17.563828 | 86567.656 | 42.323351 |
35.153981 | 17.563828 | 86567.656 | 28.429985 |
35.153981 | 17.563828 | 86567.656 | 27.310371 |
35.153981 | 17.563828 | 86567.656 | 27.332306 |
35.153981 | 17.563828 | 86567.656 | 27.568340 |
35.153981 | 17.563828 | 86567.656 | 28.112888 |
35.153981 | 17.563828 | 86567.656 | 27.656078 |
35.153981 | 17.563828 | 86567.656 | 28.559923 |
35.153981 | 17.563828 | 86567.656 | 28.425694 |
35.153981 | 17.563828 | 86567.656 | 27.448177 |
The average value is:
Metric | Value |
---|---|
GFLOPs | 35.153981 |
GMACs | 17.563828 |
Params | 86567.656 |
Latency (ms) | 29.316711 |
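To see how much the first (warm-up) run pulls the average up, the ten ViT latencies from the table above can be averaged with and without it. This is only arithmetic on the values already listed; the variable names are illustrative.

```python
import statistics

# Latency column (ms) of the ten DeepSpeed runs for vit_base16, copied from the table.
vit_latency_ms = [
    42.323351, 28.429985, 27.310371, 27.332306, 27.568340,
    28.112888, 27.656078, 28.559923, 28.425694, 27.448177,
]

mean_all = statistics.mean(vit_latency_ms)
mean_without_first = statistics.mean(vit_latency_ms[1:])  # drop the warm-up run
median_all = statistics.median(vit_latency_ms)

print(f"mean (all 10 runs):      {mean_all:.3f} ms")
print(f"mean (runs 2-10 only):   {mean_without_first:.3f} ms")
print(f"median (all 10 runs):    {median_all:.3f} ms")
```

With these values the mean over all runs is about 29.32 ms, while dropping the first run gives about 27.87 ms, close to the median of about 27.88 ms.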
GFLOPs | GMACs | Params | Latency (ms) |
---|---|---|---|
8.211108 | 4.089184 | 25557.032 | 24.617195 |
8.211108 | 4.089184 | 25557.032 | 22.443771 |
8.211108 | 4.089184 | 25557.032 | 22.524357 |
8.211108 | 4.089184 | 25557.032 | 22.032738 |
8.211108 | 4.089184 | 25557.032 | 22.349596 |
8.211108 | 4.089184 | 25557.032 | 22.476673 |
8.211108 | 4.089184 | 25557.032 | 22.083759 |
8.211108 | 4.089184 | 25557.032 | 22.483349 |
8.211108 | 4.089184 | 25557.032 | 21.956444 |
8.211108 | 4.089184 | 25557.032 | 22.108555 |
The average value is:
Metric | Value |
---|---|
GFLOPs | 8.211108 |
GMACs | 4.089184 |
Params | 25557.032 |
Latency (ms) | 22.507644 |
Model | Method | FLOPs | Latency (ms) |
---|---|---|---|
vit_base16 | torch_flops | 35,164,979,282 | 23.244 |
resnet50 | torch_flops | 8,227,340,288 | 14.515 |
vit_base16 | DeepSpeed | 35,153,981,000 | 29.317 |
resnet50 | DeepSpeed | 8,211,108,000 | 22.508 |
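The gap between the two tools can be quantified directly from the numbers in this thread. The sketch below uses the torch_flops totals and the unrounded DeepSpeed averages quoted above; the dictionary layout and helper name are only illustrative.

```python
# (total FLOPs, latency in ms) per tool, taken from the results quoted in this thread.
results = {
    "vit_base16": {"torch_flops": (35_164_979_282, 23.244),
                   "DeepSpeed": (35_153_981_000, 29.316711)},
    "resnet50": {"torch_flops": (8_227_340_288, 14.515),
                 "DeepSpeed": (8_211_108_000, 22.507644)},
}


def compare(entry):
    """Return (relative FLOPs difference, DeepSpeed/torch_flops latency ratio)."""
    tf_flops, tf_ms = entry["torch_flops"]
    ds_flops, ds_ms = entry["DeepSpeed"]
    return (tf_flops - ds_flops) / ds_flops, ds_ms / tf_ms


summary = {model: compare(entry) for model, entry in results.items()}
for model, (flops_diff, latency_ratio) in summary.items():
    print(f"{model}: FLOPs differ by {flops_diff:.3%}, latency ratio {latency_ratio:.2f}x")
```

On these figures the FLOPs counts agree to within roughly 0.2%, whereas the reported latencies differ by about 26% for vit_base16 and about 55% for resnet50.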
Take vitb16 as an example: whether the first run is done in no_grad() and whether the following evaluation is done in no_grad() both influence the results. Here is a table:
Case Number | Case Description | Total FLOPs | Total Time (ms) | Max Memory (Bytes) |
---|---|---|---|---|
1 | First run in no_grad() + Evaluation in no_grad() | 35,164,979,282 | 23.092 | 362,289,152 |
2 | First run in no_grad() only | 35,164,979,282 | 27.574 | 509,074,944 |
3 | Evaluation in no_grad() only | 35,164,979,282 | 140.948 | 362,289,152 |
4 | Neither (no first run, evaluation without no_grad()) | 35,164,979,282 | 145.391 | 509,074,944 |
It seems that case 1 is the most accurate one. If the first run is skipped, the time is strangely higher, and I don't know the reason. If the following run (the evaluation) isn't done in torch.no_grad(), the time and max memory are higher; I guess that's because the model is in training mode.
As for DeepSpeed, the function "get_model_profiler" adds a line of code, `model.eval()`, so I guess that's how they put the model into evaluation mode.
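The two switches mentioned in the guesses above are independent in PyTorch: `model.eval()` changes how layers such as Dropout behave, while `torch.no_grad()` stops autograd from recording the graph. A minimal CPU-only sketch with a toy model (not the models measured in this thread):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(2, 8)

model.eval()               # layer behaviour: Dropout becomes a no-op
y_eval = model(x)          # autograd still records the graph here

with torch.no_grad():      # graph recording off; model.training is left untouched
    y_nograd = model(x)

eval_tracks_grad = y_eval.requires_grad
nograd_tracks_grad = y_nograd.requires_grad
print(model.training, eval_tracks_grad, nograd_tracks_grad)  # False True False
```

So a model can be in evaluation mode and still keep intermediate results alive for a backward pass, which is consistent with the higher max memory in the cases without `no_grad()`.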
However, the test for the Swin block seems strange. This is the result from DeepSpeed:
This is from torch_flops. The time is much higher than DeepSpeed's time:
And this is the code:
`os.environ['TIMM_FUSED_ATTN']` affects the attention module used in `timm`. `TIMM_FUSED_ATTN=0` means using the naive operations instead of the optimized fused attention. The optimized fused attention is not yet supported in torch_flops, so `TIMM_FUSED_ATTN` is set to 0. Therefore, setting `TIMM_FUSED_ATTN` to 1 influences the running speed and memory usage.
`model.eval()` is also considered in torch_flops. The factor influencing the memory is `torch.no_grad`. Without `torch.no_grad`, the memory for gradients is retained during inference.
In `exmple2.py`, it seems that the running speed of a Transformer block is correct. It's strange that the inference time of the Swin block in your code seems too high. Would you like to share your full code?
Sure, this is the code. BTW, the SwinTransformerBlock comes from the official repo, and I didn't change the code.
import math
import torch
import torch.nn as nn
from torch_flops import TorchFLOPsByFX
from swin_attn import SwinTransformerBlock
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

class MySwinTransformerModel(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size, mlp_ratio, depth):
        super(MySwinTransformerModel, self).__init__()
        self.layers = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio)
            for i in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

device = torch.device('cuda:0')
dim = 192
batch_size = 1
num_heads = 1
input_resolution = (256,256)
mlp_ratio = 3.
depth = 2
window_size = 16
Swin = MySwinTransformerModel(dim, input_resolution, num_heads, window_size, mlp_ratio, depth).to(device)
x = torch.randn(batch_size, math.prod(input_resolution), dim).to(device)
print(x.shape)
print("=" * 30, "Torchflops Report", "=" * 30)
# note: First run the model once for accurate time measurement in the following process.
# The input `x` and the model should be placed on GPU for memory measurement.
with torch.no_grad():
    Swin(x)
with torch.no_grad():
    flops_counter = TorchFLOPsByFX(Swin)
    flops_counter.propagate(x)
result_table = flops_counter.print_result_table()
total_flops = flops_counter.print_total_flops(show=True)
total_time = flops_counter.print_total_time()
max_memory = flops_counter.print_max_memory()
`SwinTransformerBlock` from https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py raises an error when using torch_flops. Did you change the code in `SwinTransformerBlock`?
Moreover, did you install `kernels.window_process.window_process` for swin_transformer?
Hi, I replaced the batch size computed from `shape[0]` in the function `window_reverse` with a constant, and I didn't install the window process kernel.
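Hard-coding the batch size in `window_reverse` only holds for batch size 1. The relationship the original commented-out line encodes can be written with plain integers, as in the sketch below. `batch_from_windows` is a hypothetical helper for illustration; whether an integer-only variant of that line can be traced by `TorchFLOPsByFX` is not established in this thread.

```python
def batch_from_windows(num_windows_total, H, W, window_size):
    """Recover the batch size B from the leading dimension of the windows tensor."""
    windows_per_image = (H // window_size) * (W // window_size)
    if num_windows_total % windows_per_image != 0:
        raise ValueError("window count is not a multiple of windows per image")
    return num_windows_total // windows_per_image


# Configuration used in this thread: 256x256 tokens, window_size 16 -> 256 windows per image.
b_single = batch_from_windows(256, 256, 256, 16)
b_double = batch_from_windows(512, 256, 256, 16)
print(b_single, b_double)  # 1 2
```

With batch size 1, as in the test script, the constant and the computed value coincide, so the FLOPs and timing results in this thread are unaffected by the change.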
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
WindowProcess = None
WindowProcessReverse = None
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # B = int(windows.shape[0] / (H * W / window_size / window_size))
    B = 1
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
The following image shows the result of the SwinTransformer code run on my machine. The running time (10.525 ms) is less than that of ViT (14.015 ms), which seems reasonable.
My env: Ubuntu Server 22.04 LTS, GPU=NVIDIA GeForce RTX 4090, pytorch=2.1.0, pytorch-cuda=11.8, timm=0.9.8, torch-flops=0.3.5.
for _ in range(10):
    with torch.no_grad():
        flops_counter = TorchFLOPsByFX(Swin)
        flops_counter.propagate(x)
    # result_table = flops_counter.print_result_table()
    # total_flops = flops_counter.print_total_flops(show=True)
    total_time = flops_counter.print_total_time()
    # max_memory = flops_counter.print_max_memory()
Thanks for your test. This is really strange. Btw, did you use the model class sent by me?
Yes, I used your code.
flops_test.py:
import math
import torch
import torch.nn as nn
from torch_flops import TorchFLOPsByFX
from swin_attn import SwinTransformerBlock
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

class MySwinTransformerModel(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size, mlp_ratio, depth):
        super(MySwinTransformerModel, self).__init__()
        self.layers = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio)
            for i in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

device = torch.device('cuda:0')
dim = 192
batch_size = 1
num_heads = 1
input_resolution = (256,256)
mlp_ratio = 3.
depth = 2
window_size = 16
Swin = MySwinTransformerModel(dim, input_resolution, num_heads, window_size, mlp_ratio, depth).to(device)
x = torch.randn(batch_size, math.prod(input_resolution), dim).to(device)
print(x.shape)
print("=" * 30, "Torchflops Report", "=" * 30)
# note: First run the model once for accurate time measurement in the following process.
# The input `x` and the model should be placed on GPU for memory measurement.
with torch.no_grad():
    Swin(x)
for _ in range(10):
    with torch.no_grad():
        flops_counter = TorchFLOPsByFX(Swin)
        flops_counter.propagate(x)
    # result_table = flops_counter.print_result_table()
    # total_flops = flops_counter.print_total_flops(show=True)
    total_time = flops_counter.print_total_time()
    # max_memory = flops_counter.print_max_memory()
swin_attn.py: identical to the Swin Transformer listing posted earlier in this thread (including `B = 1` in `window_reverse`).
There may be some strange reason since I got a much higher result...
I will check my server...
Okay~
Hi, I tried to compare the DeepSpeed flops profiler with torch_flops.
First of all, I found that both tools are very accurate for FLOPs but not for time.
To reduce randomness, I ran the flops profiler 10 times on the Swin block. The average is 122.767278e9 FLOPs and 4.13 ms. The picture below shows that the time counter is a little unstable:
I have reported this in https://github.com/microsoft/DeepSpeed/issues/4976.
As for torch_flops, since it's a little inconvenient to get the total time from it, I ran it only once. The FLOPs count is 122,750,500,864 (close to the one from the flops profiler) and the time is 4.8 ms. But if you run torch_flops several times, you can see that its time counter is more unstable. In the figure below, the normal time for qkv should be around 0.1 ms, but in another run it is 18 ms.
Also, other nodes sometimes show unstable timings that are much higher than the expected total time of 4.13 ms.
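On the FLOPs side of this comparison, the analytic `flops()` formulas in the Swin listing can be evaluated by hand for the test configuration (dim=192, 256x256 resolution, 1 head, window size 16, MLP ratio 3, depth 2). The sketch below mirrors those formulas with plain integers. It assumes that the formulas count one multiply-accumulate as one operation, so the result is doubled before comparing with the torch_flops total; the function names are illustrative.

```python
def window_attention_ops(n_tokens, dim, num_heads):
    """Per-window count, mirroring WindowAttention.flops(N) in the listing."""
    head_dim = dim // num_heads
    ops = n_tokens * dim * 3 * dim                       # qkv projection
    ops += num_heads * n_tokens * head_dim * n_tokens    # q @ k^T
    ops += num_heads * n_tokens * n_tokens * head_dim    # attn @ v
    ops += n_tokens * dim * dim                          # output projection
    return ops


def swin_block_ops(dim, resolution, num_heads, window_size, mlp_ratio):
    """Per-block count, mirroring SwinTransformerBlock.flops() in the listing."""
    H, W = resolution
    n_windows = (H * W) // (window_size * window_size)
    ops = dim * H * W                                    # norm1
    ops += n_windows * window_attention_ops(window_size * window_size, dim, num_heads)
    ops += 2 * H * W * dim * dim * mlp_ratio             # the two MLP linear layers
    ops += dim * H * W                                   # norm2
    return ops


depth = 2
per_block = swin_block_ops(dim=192, resolution=(256, 256), num_heads=1,
                           window_size=16, mlp_ratio=3)  # integer 3 keeps the result exact
analytic_macs = depth * per_block
analytic_flops = 2 * analytic_macs      # assumption: 1 multiply-accumulate = 2 FLOPs
measured_flops = 122_750_500_864        # torch_flops total quoted in this thread
relative_gap = (measured_flops - analytic_flops) / measured_flops
print(f"{analytic_macs:,} MACs -> {analytic_flops:,} FLOPs, gap {relative_gap:.3%}")
```

Under that assumption the analytic estimate is 122,507,231,232 FLOPs, about 0.2% below the measured total. The formulas leave out softmax, bias additions, and activations, which are plausible sources of the remainder.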