Hi @feiwushiwo, it appears that you are encountering numerical instability, possibly resulting in NaN values. We recommend turning off AMP, or referring to https://github.com/LeapLabTHU/FLatten-Transformer/issues/15#issuecomment-1859776961 for guidance on addressing and mitigating this problem.
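For reference, a minimal sketch of that approach, assuming a CUDA device: disable autocast locally so only the numerically sensitive attention math runs in fp32 while the rest of the model stays under AMP.

```python
import torch

# Sketch: keep a numerically sensitive block in fp32 even when the
# surrounding model runs under AMP. q, k, v are stand-in tensors here.
q = torch.randn(8, 49, 32, device='cuda', dtype=torch.float16)
k = torch.randn(8, 49, 32, device='cuda', dtype=torch.float16)
v = torch.randn(8, 49, 32, device='cuda', dtype=torch.float16)

with torch.autocast(device_type='cuda', enabled=False):
    q, k, v = q.float(), k.float(), v.float()           # upcast to fp32
    kv = torch.einsum("b j c, b j d -> b c d", k, v)    # O(N) linear attention core
    out = torch.einsum("b i c, b c d -> b i d", q, kv)  # fp32 result
```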
AMP is turned off; the problem appeared after I modified the code later on. My `FocusedLinearAttention` is as follows:

```python
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
# img2windows / windows2img are the window-partition helpers from this repository.


class FocusedLinearAttention(nn.Module):
    def __init__(self, dim, resolution, split_size=7, dim_out=None, num_heads=8,
                 attn_drop=0., proj_drop=0., qk_scale=None, focusing_factor=3, kernel_size=5):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out or dim
        self.resolution = resolution
        self.split_size = split_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # self.scale = qk_scale or head_dim ** -0.5
        H_sp, W_sp = self.resolution[0], self.resolution[1]
        self.H_sp = H_sp
        self.W_sp = W_sp
        stride = 1
        self.conv_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.focusing_factor = focusing_factor
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
        self.positional_encoding = nn.Parameter(torch.zeros(size=(1, self.H_sp * self.W_sp, dim)))
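    # Partition flattened (B, N, C) tokens into (H_sp x W_sp) windows, CSWin-style.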
    def im2cswin(self, x):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
        x = img2windows(x, self.H_sp, self.W_sp)
        # x = x.reshape(-1, self.H_sp * self.W_sp, C).contiguous()
        return x
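    # Window-partition v and compute LePE (locally-enhanced positional encoding)
    # with the depthwise 3x3 conv self.get_v.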
    def get_lepe(self, x, func):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
        H_sp, W_sp = self.H_sp, self.W_sp
        x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp)  ### B', C, H', W'
        lepe = func(x)  ### B', C, H', W'
        lepe = lepe.reshape(-1, C // self.num_heads, H_sp * W_sp).permute(0, 2, 1).contiguous()
        x = x.reshape(-1, C, self.H_sp * self.W_sp).permute(0, 2, 1).contiguous()
        return x, lepe
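    # Forward: 1x1 conv -> q/k/v, window partition, focused linear attention, window merge.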
    def forward(self, qkv):
        """
        qkv: B C H W (input feature map; projected to q, k, v by the 1x1 conv)
        """
        qkv = self.conv_qkv(qkv)
        q, k, v = torch.chunk(qkv.flatten(2).transpose(1, 2), 3, dim=-1)
        ### Img2Window
        H, W = self.resolution
        B, L, C = q.shape
        assert L == H * W, "flatten img_tokens has wrong size"
        q = self.im2cswin(q)
        k = self.im2cswin(k)
        v, lepe = self.get_lepe(v, self.get_v)
        k = k + self.positional_encoding
        focusing_factor = self.focusing_factor
        kernel_function = nn.ReLU()
        q = kernel_function(q) + 1e-6
        k = kernel_function(k) + 1e-6
        scale = nn.Softplus()(self.scale)
        q = q / scale
        k = k / scale
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
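        # Focused mapping: raise features to focusing_factor, then restore each row's
        # original norm; for large factors, divide by the max first to avoid overflow.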
        if float(focusing_factor) <= 6:
            q = q ** focusing_factor
            k = k ** focusing_factor
        else:
            q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factor
            k = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
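        # Split heads; the linear attention below runs in fp32 with autocast disabled,
        # choosing whichever einsum order is cheaper.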
        q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k, v = q.float(), k.float(), v.float()
            z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
            if i * j * (c + d) > c * d * (i + j):
                kv = torch.einsum("b j c, b j d -> b c d", k, v)
                x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
            else:
                qk = torch.einsum("b i c, b j c -> b i j", q, k)
                x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)
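            # Depthwise-conv branch over v restores local detail; the Half-weight vs.
            # float-input error is raised at the self.dwc call below.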
            num = int(v.shape[1] ** 0.5)
            feature_map = rearrange(v, "b (h w) c -> b c h w", w=num, h=num)
            feature_map = feature_map.float()
            feature_map = rearrange(self.dwc(feature_map), "b c h w -> b (h w) c")
            x = x + feature_map
        x = x + lepe
        x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)
        x = windows2img(x, self.H_sp, self.W_sp, H, W).permute(0, 3, 1, 2)
        return x
```

The code is as above. Could you please take a look and help me find the cause of the error?
There is an inconsistency in the data types: the model weight type is Half, while the input tensor type is float. I recommend reviewing and adjusting the model weight type to ensure compatibility.
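A minimal sketch of one possible fix, assuming the mismatch is raised at the `self.dwc` call in the code above: cast the conv input to the conv's own parameter dtype (a hypothetical patch, not the repository's official fix):

```python
# Inside forward(), replacing the feature_map block above (hypothetical patch).
feature_map = rearrange(v, "b (h w) c -> b c h w", w=num, h=num)
feature_map = feature_map.to(self.dwc.weight.dtype)   # match input dtype to the conv weights
feature_map = rearrange(self.dwc(feature_map), "b c h w -> b (h w) c")
x = x + feature_map.float()                           # stay in fp32 for the residual add
```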
We have fixed the numerical instability problem, and the model can now be trained with automatic mixed precision (AMP).
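For completeness, a minimal AMP training sketch (the standard PyTorch recipe; the tiny model and random data are placeholders, not code from this repository):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 2).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(4, 16, device='cuda')
    y = torch.randint(0, 2, (4,), device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # forward pass in mixed precision
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()      # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscales gradients, then steps
    scaler.update()
```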
I would also like to ask: my experimental results are really quite mediocre, and I do not understand why that is.