Closed zhousai-zs closed 6 months ago
图4中mask的值是经过归一化后得到的。此外,mask的分布也受图像内容和所在层的影响:对于尺寸较小、没有明显复杂度差异的图片,mask确实会比较随机;而且为了避免块效应,某些层的mask会去弥补简单区域或关注边缘区域。图4中的mask应该是对pred_score的初始值做归一化得到的,其中第8、9、10个predictor的结果会比较稳定,如下图:
首先感谢您的耐心回复。像您一样,我保存了每个过程中产生的mask(一共20张),并用它们对原图进行处理,分别使用了CAMixerSRx4_DF.pth和CAMixerSRx4.pth两种预训练权重,但都没有得到和您一样的结果,mask表现得十分零散。请问是否是代码操作层面有问题?下图是我的全部代码细节:保存所有mask和pred_score;加载图片和模型,在测试这张图的过程中保存mask和pred_score;使用mask对原图进行掩码操作;使用所有mask处理原图并保存。
此外,我还根据您的说明自己训练了一次(单张4090,训练了五个小时),使用这个训练好的模型得到的结果也与上述两种预训练模型差不多。希望您能解答我的疑问,谢谢!
代码好像有点问题:这样直接相乘与我们预设的分块方式没有对齐,从而导致结果看起来很随机。推理时我们并不是用mask直接做掩码乘法,而是通过argsort得到前k个使用attention的块的索引,再根据索引用batch_index_select函数取出参与attention计算的token。你可以试一下这段代码:
def visual(x,mask):
_, _, h, w = x.size()
wsize = 16
i=16
mod_pad_h = (wsize - h % wsize) % wsize
mod_pad_w = (wsize - w % wsize) % wsize
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
res = x
mean_y = x*0
B,C,H,W = x.shape
origin_x = x
x = rearrange(x,'b c (h dh) (w dw) -> b (h w) (dh dw c)', dh=wsize, dw=wsize)
x_input = x
print(x.shape)
p = len(mask)
for i in range(p):
# i=18
idx1,idx2 = mask[i][0]
offset = mask[i][1]
x1,x2 = batch_index_select(x_input,idx1), batch_index_select(x_input,idx2)
print(x1.shape[1],x_input.shape[1])
x2 = x2*0.1 + 0.9
x1 = x1*1.0
x = batch_index_fill(x_input.clone(), x1, x2, idx1, idx2)
x = rearrange(x,'b (h w) (dh dw c) -> b c (h dh) (w dw)', h=H//wsize, w=W//wsize, dh=wsize, dw=wsize, c = C)
x_warp = flow_warp(origin_x, offset.permute(0,2,3,1), interp_mode='bilinear', padding_mode='border')
# x = torch.max(x,res)
img1 = transforms.ToPILImage()(x.squeeze(0))
# img1.show()
img1.save('figures/8/temp{}.png'.format(i))
完整代码如下:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import torchvision.transforms as transforms
# from DCNv2_latest.dcn_v2 import DCNv2
from basicsr.archs.arch_util import to_2tuple, trunc_normal_, flow_warp
from PIL import Image
from basicsr.utils.registry import ARCH_REGISTRY
class LayerNorm(nn.Module):
    r"""LayerNorm over the channel dimension for channels_first (default) or channels_last inputs.

    ``data_format`` gives the ordering of the dimensions in the inputs.
    ``channels_first`` (the default) expects (batch_size, channels, height, width);
    ``channels_last`` expects (batch_size, height, width, channels).

    Args:
        normalized_shape (int): number of channels.
        eps (float): value added to the variance for numerical stability.
        data_format (str): ``"channels_first"`` or ``"channels_last"``.

    Raises:
        NotImplementedError: if ``data_format`` is not one of the two formats.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data_format: {data_format!r}")
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: F.layer_norm only normalizes trailing dims, so normalize
        # over dim 1 by hand (biased variance, matching F.layer_norm).
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]
class ElementScale(nn.Module):
    """Multiply the input by a learnable per-channel factor of shape (1, embed_dims, 1, 1).

    Args:
        embed_dims (int): number of channels; one scale value per channel.
        init_value (float): initial value of every scale entry.
        requires_grad (bool): whether the scale parameter is trainable.
    """

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super().__init__()
        initial = torch.ones((1, embed_dims, 1, 1)) * init_value
        self.scale = nn.Parameter(initial, requires_grad=requires_grad)

    def forward(self, x):
        scaled = x * self.scale
        return scaled
def ones(tensor):
    """Fill ``tensor`` in place with 0.5 via ``.data``; do nothing for ``None``.

    NOTE(review): despite the name, the fill value is 0.5, not 1.0. It is left
    unchanged because callers may rely on it.
    """
    if tensor is None:
        return
    tensor.data.fill_(0.5)
def zeros(tensor):
    """Fill ``tensor`` in place with 0.0 via ``.data``; do nothing for ``None``."""
    if tensor is None:
        return
    tensor.data.fill_(0.0)
def batch_index_select(x, idx):
    """Gather positions along dim 1 of ``x`` independently for each batch element.

    Args:
        x: tensor of shape (B, N, C) or (B, N).
        idx: integer index tensor of shape (B, N_new); ``idx[b]`` selects
            positions along dim 1 of ``x[b]``.

    Returns:
        Tensor of shape (B, N_new, C) or (B, N_new) respectively.

    Raises:
        NotImplementedError: for any other rank of ``x``.
    """
    rank = len(x.size())
    if rank not in (2, 3):
        raise NotImplementedError
    batch, length = x.shape[0], x.shape[1]
    trailing = tuple(x.shape[2:])  # () for 2-D input, (C,) for 3-D input
    n_new = idx.size(1)
    # Shift each batch row's indices into the flattened (B * N) index space.
    base = torch.arange(batch, dtype=torch.long, device=x.device).view(batch, 1) * length
    flat_idx = (idx + base).reshape(-1)
    picked = x.reshape(batch * length, *trailing)[flat_idx]
    return picked.reshape(batch, n_new, *trailing)
def batch_index_fill(x, x1, x2, idx1, idx2):
    """Write two sets of rows back into ``x`` for each batch element.

    Row ``idx1[b, i]`` of ``x[b]`` receives ``x1[b, i]`` and row ``idx2[b, j]``
    receives ``x2[b, j]``. ``x2`` is written second, so it wins if the two
    index sets overlap.

    Args:
        x: destination tensor of shape (B, N, C).
        x1, x2: source tensors of shape (B, N1, C) and (B, N2, C).
        idx1, idx2: integer index tensors of shape (B, N1) and (B, N2).

    Returns:
        Tensor of shape (B, N, C).

    NOTE: when ``x`` is contiguous, ``reshape`` returns a view, so ``x`` itself
    is modified in place. The callers in this file pass ``x.clone()``.
    """
    batch, length, chan = x.size()
    _, n1, _ = x1.size()
    batch, n2, chan = x2.size()
    rows = torch.arange(batch, dtype=torch.long, device=x.device).view(batch, 1) * length
    flat = x.reshape(batch * length, chan)
    flat[(idx1 + rows).reshape(-1)] = x1.reshape(batch * n1, chan)
    flat[(idx2 + rows).reshape(-1)] = x2.reshape(batch * n2, chan)
    return flat.reshape(batch, length, chan)
class PredictorLG(nn.Module):
    """ Importance Score Predictor

    Scores every ``window_size`` x ``window_size`` window of a conditioning map.
    It also predicts a 2-channel offset field, a channel-attention vector and a
    spatial-attention map.

    Args:
        dim: channel count of the main feature; ``out_CA`` produces ``dim`` channels.
        window_size: side length of the square windows that are scored.
        k: number of extra conditioning channels; the input must have
            ``dim + k`` channels (``cdim``).
    """
    def __init__(self, dim, window_size=8, k=4):
        super().__init__()
        self.window_size = window_size
        cdim = dim + k
        # After the channel mean in forward, each window flattens to window_size**2 values.
        embed_dim = window_size**2
        # Shared stem: 1x1 conv to cdim//4 channels -> LayerNorm -> LeakyReLU.
        self.in_conv = nn.Sequential(
            nn.Conv2d(cdim, cdim//4, 1),
            LayerNorm(cdim//4),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )
        # Per-pixel 2-channel offsets (bounded by tanh * 8.0 in forward).
        self.out_offsets = nn.Sequential(
            nn.Conv2d(cdim//4, cdim//8, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(cdim//8, 2, 1),
        )
        # Per-window MLP ending in a 2-way softmax; column 0 is used as the window score.
        self.out_mask = nn.Sequential(
            nn.Linear(embed_dim, window_size),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Linear(window_size, 2),
            nn.Softmax(dim=-1)
        )
        # Channel attention: global average pool -> 1x1 conv to dim channels -> sigmoid.
        self.out_CA = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(cdim//4, dim, 1),
            nn.Sigmoid(),
        )
        # Spatial attention: 3x3 conv to a single channel -> sigmoid.
        self.out_SA = nn.Sequential(
            nn.Conv2d(cdim//4, 1, 3, 1, 1),
            nn.Sigmoid(),
        )
    def forward(self, input_x, mask=None, ratio=0.5, train_mode=False):
        """Predict window routing, offsets and attention maps.

        Args:
            input_x: conditioning tensor with ``dim + k`` channels. Its height
                and width must be divisible by ``window_size``, which the window
                ``rearrange`` requires.
            mask: unused; it is rebound below.
            ratio: has no effect. It is unused in the training branch and
                overwritten with 1.0 in the inference branch.
            train_mode: return the training-style output even when
                ``self.training`` is False.

        Returns:
            Training branch (``self.training or train_mode``): ``(mask, offsets, ca, sa)``.
            Here ``mask`` is a hard 0/1 Gumbel sample of shape (B, N, 1) for N windows.
            Inference branch: ``([idx1, idx2], offsets, ca, sa)``. ``idx1`` holds the
            ``num_keep_node`` highest-scoring window indices and ``idx2`` the rest,
            both ordered by descending score.
        """
        x = self.in_conv(input_x)
        offsets = self.out_offsets(x)
        offsets = offsets.tanh().mul(8.0)  # offsets lie in (-8, 8)
        ca = self.out_CA(x)
        sa = self.out_SA(x)
        # Collapse channels, then split into windows: (B, num_windows, window_size**2).
        x = torch.mean(x, keepdim=True, dim=1)
        x = rearrange(x,'b c (h dh) (w dw) -> b (h w) (dh dw c)', dh=self.window_size, dw=self.window_size)
        B, N, C = x.size()  # not used below
        pred_score = self.out_mask(x)  # (B, N, 2), softmax over the last dim
        # Hard (0/1) Gumbel-softmax sample, column 0 kept -> (B, N, 1).
        # NOTE(review): this draws random noise even outside training, so everything
        # derived from `mask` below (r, num_keep_node) can vary between runs on the
        # same input — confirm this is intended.
        mask = F.gumbel_softmax(pred_score, hard=True, dim=2)[:, :, 0:1]
        if self.training or train_mode:
            return mask, offsets, ca, sa
        else:
            # NOTE(review): hard-codes ratio, ignoring the argument and CAMixer.ratio.
            ratio = 1.0
            score = pred_score[:, : , 0]
            B, N = score.shape
            # Fraction of windows the Gumbel sample set to 1, averaged over batch and windows.
            r = torch.mean(mask,dim=(0,1))
            num_keep_node = min(N,int(N * r * ratio))
            # Rank windows by score: the top num_keep_node go to idx1, the rest to idx2.
            idx = torch.argsort(score, dim=1, descending=True)
            idx1 = idx[:, :num_keep_node]
            idx2 = idx[:, num_keep_node:]
            return [idx1, idx2], offsets, ca, sa
class PostNormResidual(nn.Module):
    """Wrap ``fn`` with a residual connection followed by LayerNorm: ``norm(fn(x) + x)``.

    Args:
        dim: number of channels passed to LayerNorm.
        fn: module applied to the input before the residual addition.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        summed = branch + x
        return self.norm(summed)
class CAMixer(nn.Module):
    """Content-aware mixer.

    A PredictorLG (``self.route``) scores every ``window_size`` x ``window_size``
    window. Selected windows go through per-window self-attention. The other
    windows keep ``v`` multiplied by the predicted spatial attention ``sa``.
    A depth-wise conv branch gated by the predicted channel attention ``ca`` is
    then added as a residual, followed by a 1x1 output projection.

    Args:
        dim: number of channels.
        window_size: side length of the attention windows. H and W of the input
            must be multiples of it, which the window ``rearrange`` calls require.
        bias: whether the projection layers use a bias.
        is_deformable: see NOTE(review) in ``forward``.
        ratio: stored and forwarded to the predictor's ``ratio`` argument.
    """
    def __init__(self, dim, window_size=8, bias=True, is_deformable=True, ratio=0.5):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.is_deformable = is_deformable
        self.ratio = ratio
        k = 3  # kernel size of the depth-wise convs in conv_sptial
        d = 2  # dilation of the second depth-wise conv
        self.project_v = nn.Conv2d(dim, dim, 1, 1, 0, bias = bias)
        self.project_q = nn.Linear(dim, dim, bias = bias)
        self.project_k = nn.Linear(dim, dim, bias = bias)
        # self.sim_attn = nn.Conv2d(dim, dim, kernel_size=3, bias=True, groups=dim, padding=1)
        # Conv
        # self.conv_sptial = nn.Conv2d(dim, dim, kernel_size=3, bias=True, groups=dim, padding=1)
        # Depth-wise 3x3 conv followed by a dilated depth-wise 3x3 conv; spatial size is preserved.
        self.conv_sptial = nn.Sequential(
            nn.Conv2d(dim, dim, k, padding=k//2, groups=dim),
            nn.Conv2d(dim, dim, k, stride=1, padding=((k//2)*d), groups=dim, dilation=d))
        self.project_out = nn.Conv2d(dim, dim, 1, 1, 0, bias = bias)
        self.act = nn.GELU()
        # Predictor. With PredictorLG's default k=4 it expects dim + 4 input channels.
        self.route = PredictorLG(dim,window_size)
    def forward(self,x,condition_global=None, mask=None, train_mode=False):
        """Mix ``x`` and return ``(out, decision)``.

        Args:
            x: feature map of shape (N, C, H, W); ``N`` is the batch size here.
            condition_global: optional conditioning map concatenated to the
                predictor input.
            mask: ignored on entry; it is rebound to the predictor output.
            train_mode: use the mask-multiplication (training) path even when
                ``self.training`` is False.

        Returns:
            If ``self.training``: ``(out, torch.mean(mask, dim=1))``.
            Otherwise: ``(out, [mask, offsets])``. In inference mode ``mask`` is
            the predictor's ``[idx1, idx2]``; with ``train_mode=True`` it is the
            0/1 mask tensor.

        NOTE(review): with ``is_deformable=False``, ``_condition`` is never
        assigned, so the ``self.route`` call raises UnboundLocalError.
        NOTE(review): ``condition_wind`` has 2 channels, so the predictor input
        has ``dim + 4`` channels only when ``condition_global`` has 2 channels.
        OSR's global_predictor outputs 2. With ``condition_global=None`` the
        input has ``dim + 2`` channels, which does not match ``route.in_conv``.
        """
        N,C,H,W = x.shape
        v = self.project_v(x)
        if self.is_deformable:
            # Per-window coordinate grid in [-1, 1] (2 channels), tiled to (N, 2, H, W).
            condition_wind = torch.stack(torch.meshgrid(torch.linspace(-1,1,self.window_size),torch.linspace(-1,1,self.window_size)))\
                .type_as(x).unsqueeze(0).repeat(N, 1, H//self.window_size, W//self.window_size)
            if condition_global is None:
                _condition = torch.cat([v, condition_wind], dim=1)
            else:
                _condition = torch.cat([v, condition_global, condition_wind], dim=1)
        mask, offsets, ca, sa = self.route(_condition,ratio=self.ratio,train_mode=train_mode)
        # Queries are x itself; keys add a copy of x warped by the predicted offsets.
        q = x
        k = x + flow_warp(x, offsets.permute(0,2,3,1), interp_mode='bilinear', padding_mode='border')
        qk = torch.cat([q,k],dim=1)
        # Attn branch
        vs = v*sa  # spatial-attention-weighted v, used for windows that skip attention
        # Flatten into windows: (N, num_windows, window_size * window_size * channels).
        v = rearrange(v,'b c (h dh) (w dw) -> b (h w) (dh dw c)', dh=self.window_size, dw=self.window_size)
        vs = rearrange(vs,'b c (h dh) (w dw) -> b (h w) (dh dw c)', dh=self.window_size, dw=self.window_size)
        qk = rearrange(qk,'b c (h dh) (w dw) -> b (h w) (dh dw c)', dh=self.window_size, dw=self.window_size)
        if self.training or train_mode:
            # Keep every window and zero out the unselected ones through the 0/1 mask.
            N_ = v.shape[1]
            v1,v2 = v*mask, vs*(1-mask)
            qk1 = qk*mask
        else:
            # Gather only the selected windows (idx1) for attention; idx2 windows take vs.
            idx1, idx2 = mask
            _, N_ = idx1.shape
            v1,v2 = batch_index_select(v,idx1),batch_index_select(vs,idx2)
            qk1 = batch_index_select(qk,idx1)
        # One attention problem per window: tokens are the window's pixels.
        v1 = rearrange(v1,'b n (dh dw c) -> (b n) (dh dw) c', n=N_, dh=self.window_size, dw=self.window_size)
        qk1 = rearrange(qk1,'b n (dh dw c) -> b (n dh dw) c', n=N_, dh=self.window_size, dw=self.window_size)
        # qk was concatenated along channels (q first, then k), so split it in half.
        q1,k1 = torch.chunk(qk1,2,dim=2)
        q1 = self.project_q(q1)
        k1 = self.project_k(k1)
        q1 = rearrange(q1,'b (n dh dw) c -> (b n) (dh dw) c', n=N_, dh=self.window_size, dw=self.window_size)
        k1 = rearrange(k1,'b (n dh dw) c -> (b n) (dh dw) c', n=N_, dh=self.window_size, dw=self.window_size)
        #calculate attention: Softmax(Q@K)@V  (no 1/sqrt(d) scaling is applied)
        attn = q1 @ k1.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        f_attn = attn@v1
        f_attn = rearrange(f_attn,'(b n) (dh dw) c -> b n (dh dw c)',
            b=N, n=N_, dh=self.window_size, dw=self.window_size)
        if not (self.training or train_mode):
            # Scatter attention results and the non-attention windows back into place.
            attn_out = batch_index_fill(v.clone(), f_attn, v2.clone(), idx1, idx2)
        else:
            # Where mask is 0, v1 is zero, so f_attn is zero there and v2 supplies the output.
            attn_out = f_attn + v2
        attn_out = rearrange(
            attn_out, 'b (h w) (dh dw c) -> b (c) (h dh) (w dw)',
            h=H//self.window_size, w=W//self.window_size, dh=self.window_size, dw=self.window_size
        )
        out = attn_out
        # Depth-wise conv branch gated by channel attention, added as a residual.
        out = self.act(self.conv_sptial(out))*ca + out
        out = self.project_out(out)
        # NOTE(review): this checks only self.training, unlike the branches above,
        # which also consider train_mode.
        if self.training:
            return out, torch.mean(mask,dim=1)
        return out, [mask,offsets]
class GatedFeedForward(nn.Module):
    """Gated depth-wise convolutional feed-forward layer.

    The layer runs four steps:

    1. A 1x1 conv expands ``dim`` to ``2 * hidden`` channels.
    2. A 3x3 depth-wise conv mixes spatially.
    3. The result is split into two halves ``a`` and ``b`` and combined as
       ``gelu(a) * b``.
    4. A 1x1 conv projects back to ``dim`` channels.

    Args:
        dim: number of input/output channels.
        mult: expansion factor; ``hidden = int(dim * mult)``.
        bias: whether the convolutions use a bias.
        dropout: accepted for interface compatibility; not used.
    """

    def __init__(self, dim, mult = 1, bias=False, dropout = 0.):
        super().__init__()
        hidden = int(dim * mult)
        doubled = 2 * hidden
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1,
                                groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
class Block(nn.Module):
    """Post-norm block: ``x = norm1(x + mixer(x))`` then ``x = norm2(x + ffn(x))``.

    Args:
        n_feats: number of channels.
        window_size: window size passed to CAMixer.
        ratio: ratio passed to CAMixer.
    """

    def __init__(self, n_feats, window_size=8, ratio=0.5):
        super(Block, self).__init__()
        self.n_feats = n_feats
        self.norm1 = LayerNorm(n_feats)
        self.mixer = CAMixer(n_feats, window_size=window_size, ratio=ratio)
        self.norm2 = LayerNorm(n_feats)
        self.ffn = GatedFeedForward(n_feats)

    def forward(self, x, global_condition=None):
        """Apply mixer and FFN and return ``(x, decision)``.

        ``decision`` is the second value returned by CAMixer. Training and
        inference run the same computation here; the mode-dependent behaviour
        lives inside CAMixer.
        """
        res, decision = self.mixer(x, global_condition)
        x = self.norm1(x + res)
        res = self.ffn(x)
        x = self.norm2(x + res)
        return x, decision
class Group(nn.Module):
    """A stack of ``n_block`` Blocks, then a 1x1 conv and a group-level residual.

    Args:
        n_feats: number of channels.
        n_block: number of Blocks in the group.
        window_size: window size passed to every Block.
        ratio: ratio passed to every Block.
    """

    def __init__(self, n_feats, n_block, window_size=8, ratio=0.5):
        super(Group, self).__init__()
        self.n_feats = n_feats
        self.body = nn.ModuleList(
            [Block(n_feats, window_size=window_size, ratio=ratio) for _ in range(n_block)]
        )
        self.body_tail = nn.Conv2d(n_feats, n_feats, 1, 1, 0)

    def forward(self, x, condition_global=None):
        """Run every Block in order.

        Returns ``(x, decision)``, where ``decision`` holds one entry per Block.
        Training and inference share this code path.
        """
        decision = []
        shortcut = x.clone()  # defensive copy of the group input for the residual
        for blk in self.body:
            x, mask = blk(x, condition_global)
            decision.append(mask)
        x = self.body_tail(x) + shortcut
        return x, decision
#@ARCH_REGISTRY.register()
class OSR(nn.Module):
    """CAMixer-based super-resolution network.

    Structure: head conv -> Groups of Blocks -> body_tail conv + global
    residual -> conv + PixelShuffle upsampling by ``scale``.

    Args:
        n_block: NOTE(review) unused. The per-group depths are hard-coded as
            ``[4, 4, 6, 6]``, so ``n_group > 4`` raises IndexError.
        n_group: number of Groups.
        n_colors: number of image channels.
        n_feats: number of feature channels.
        scale: upsampling factor.

    NOTE: submodule names and creation order determine state_dict keys and
    parameter init order; pretrained checkpoints are loaded with strict=True.
    """
    def __init__(self, n_block=4, n_group=4, n_colors=3, n_feats=60, scale=4):
        super().__init__()
        self.head = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
        # Groups alternate between these window sizes (index i % 2).
        self.window_sizes = [16,16]
        ratios = [0.5 for _ in range(n_group)]
        blocks = [4,4,6,6]
        # Produces a 2-channel map passed to every Block as its global condition.
        self.global_predictor = nn.Sequential(nn.Conv2d(n_feats, 8, 1, 1, 0, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(8, 2, 3, 1, 1, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True))
        self.scale = scale
        # define body module
        self.body = nn.ModuleList([Group(n_feats, n_block=blocks[i], window_size=self.window_sizes[i%2], ratio=ratios[i]) for i in range(n_group)])
        self.body_tail = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        # define tail module
        self.tail = nn.Sequential(
            nn.Conv2d(n_feats, n_colors*(scale**2), 3, 1, 1),
            nn.PixelShuffle(scale)
        )
    def forward(self, x):
        """Super-resolve ``x`` (N, n_colors, H, W).

        Returns:
            ``(sr, info)``. ``sr`` is cropped to (H * scale, W * scale).
            In training, ``info`` is the mean over batch and blocks of the
            per-block decisions (each block returns ``torch.mean(mask, dim=1)``).
            In eval, ``info`` is the flat list of per-block decisions, one
            ``[[idx1, idx2], offsets]`` entry per Block (20 with the default depths).
        """
        decision = []
        H, W = x.shape[2:]
        x = self.check_image_size(x)
        x = self.head(x)
        condition_global = self.global_predictor(x)
        shortcut = x.clone()
        for _, blk in enumerate(self.body):
            x, mask = blk(x,condition_global)
            decision.extend(mask)
        x = self.body_tail(x)
        x = x + shortcut
        x = self.tail(x)
        if self.training:
            return x[:, :, 0:H*self.scale, 0:W*self.scale], torch.mean(torch.cat(decision,dim=1),dim=(0,1))
        else:
            return x[:, :, 0:H*self.scale, 0:W*self.scale], decision
    def check_image_size(self, x):
        """Reflect-pad bottom/right so H and W are multiples of the LCM of all window sizes."""
        _, _, h, w = x.size()
        wsize = self.window_sizes[0]
        for i in range(1, len(self.window_sizes)):
            # Running least common multiple: a * b // gcd(a, b).
            wsize = wsize*self.window_sizes[i] // math.gcd(wsize, self.window_sizes[i])
        mod_pad_h = (wsize - h % wsize) % wsize
        mod_pad_w = (wsize - w % wsize) % wsize
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x
def visual(x,mask):
_, _, h, w = x.size()
wsize = 16
i=16
mod_pad_h = (wsize - h % wsize) % wsize
mod_pad_w = (wsize - w % wsize) % wsize
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
res = x
mean_y = x*0
B,C,H,W = x.shape
origin_x = x
x = rearrange(x,'b c (h dh) (w dw) -> b (h w) (dh dw c)', dh=wsize, dw=wsize)
x_input = x
print(x.shape)
p = len(mask)
for i in range(p):
# i=18
idx1,idx2 = mask[i][0]
offset = mask[i][1]
x1,x2 = batch_index_select(x_input,idx1), batch_index_select(x_input,idx2)
print(x1.shape[1],x_input.shape[1])
x2 = x2*0.1 + 0.9
x1 = x1*1.0
x = batch_index_fill(x_input.clone(), x1, x2, idx1, idx2)
x = rearrange(x,'b (h w) (dh dw c) -> b c (h dh) (w dw)', h=H//wsize, w=W//wsize, dh=wsize, dw=wsize, c = C)
x_warp = flow_warp(origin_x, offset.permute(0,2,3,1), interp_mode='bilinear', padding_mode='border')
# x = torch.max(x,res)
img1 = transforms.ToPILImage()(x.squeeze(0))
# img1.show()
img1.save('figures/8/temp{}.png'.format(i))
if __name__ == '__main__':
    # Demo: super-resolve one LR image with a pretrained checkpoint and dump
    # the per-block attention-window visualizations.
    import os

    os.makedirs('figures/8', exist_ok=True)
    f = OSR(scale=4).cuda()
    checkpoint = torch.load(r'C:\Z_Develop\CAMixer\CAMixer\pretrained_models\LSR\CAMixerSRx4.pth',
                            map_location='cpu')
    f.load_state_dict(checkpoint['params_ema'], strict=True)
    f.eval()
    img = Image.open(r'C:\Z_Document\dataset\test2k\test2k\LR\X4\1298.png').convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).cuda()
    # Inference only: no autograd graph, which matters for 2k-resolution inputs.
    with torch.no_grad():
        p1, mask = f(x)
    # Clamp before conversion: ToPILImage scales by 255 and casts to uint8, so
    # values outside [0, 1] would wrap around.
    img1 = transforms.ToPILImage()(p1.squeeze(0).clamp(0, 1).cpu())
    img1.save('figures/8/SR.png')
    visual(x, mask)
十分感谢您的解惑!
没事没事
我试图复现文章中mask的作用。文章图4展示了不同复杂度图像对应的概率,但在复现过程中我发现可视化出的mask十分随机,并不能反映图像的复杂程度,mask的上一步pred_score也没有表现出预期的规律。我想请问:作者在图4中可视化概率以及mask与图片复杂度的对应关系时,使用的是哪一部分的值?这个值应该从整个SR模型中的第几个predictor得到?Looking forward to your reply, thanks.