Hi @nviwch, thank you for your interest in our work. Training with mixed precision occasionally results in NaN values on certain devices. To troubleshoot this issue, consider disabling mixed precision in the Flatten module. For instance, you can replace the forward function in Flatten-Swin with the following code:
def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
    q, k, v = qkv.unbind(0)
    k = k + self.positional_encoding
    focusing_factor = self.focusing_factor
    kernel_function = nn.ReLU()
    q = kernel_function(q) + 1e-6
    k = kernel_function(k) + 1e-6
    scale = nn.Softplus()(self.scale)
    q = q / scale
    k = k / scale
    q_norm = q.norm(dim=-1, keepdim=True)
    k_norm = k.norm(dim=-1, keepdim=True)
    if float(focusing_factor) <= 6:
        q = q ** focusing_factor
        k = k ** focusing_factor
    else:
        q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factor
        k = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factor
    q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
    k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
    q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
    i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]
    with torch.autocast(enabled=False, device_type='cuda'):
        q, k, v = q.float(), k.float(), v.float()
        z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("b j c, b j d -> b c d", k, v)
            x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
        else:
            qk = torch.einsum("b i c, b j c -> b i j", q, k)
            x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)
    num = int(v.shape[1] ** 0.5)
    feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)
    feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")
    x = x + feature_map
    x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
Currently, we haven't experimented with identity/convolution as an alternative to depthwise convolution in the attention block, which could be a valuable avenue for future investigation and improvement.
While debugging I found that, when computing kv, replacing v with v/64 and z with z*64, together with epsilon=1e-3, avoids the float16 values blowing up. Mathematically it is equivalent: the numerator and denominator are each scaled up by the same factor. I also noticed that multiplying q and k back by the original q_norm and k_norm seems to make no difference, since that also amounts to scaling the numerator and denominator together.
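To make the scaling argument concrete, here is a minimal single-head sketch (illustrative only; the shapes and the shared 1e-6 epsilon are assumptions, not values from the repo) showing that dividing v by 64 while multiplying the reciprocal denominator z by 64 leaves the output unchanged:

import torch

torch.manual_seed(0)
d, n = 16, 49                                  # head_dim and token count, chosen only for illustration
q = torch.rand(d, n)                           # non-negative, as after ReLU
k = torch.rand(d, n)
v = torch.randn(d, n)

# original form: out = (v @ k^T) @ q * z,  with z = 1 / (k.sum(-1) @ q + eps)
z = 1.0 / (k.sum(dim=-1) @ q + 1e-6)           # shape (n,)
out_ref = ((v @ k.T) @ q) * z

# rescaled form: v -> v / 64 and z -> z * 64; the factors cancel exactly
out_scaled = (((v / 64) @ k.T) @ q) * (z * 64)

print(torch.allclose(out_ref, out_scaled))     # True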
Below is my line-by-line simplification. It keeps your idea, and I also cut quite a lot to speed things up:
class SimplifiedFocusedLinearAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.gen_qkv = nn.Conv1d(dim, dim * 3, 1, bias=False)
        self.proj = nn.Conv1d(dim, dim, 1, bias=False)

    def forward(self, x):  # input shape is (B, C, N), reshaped from an image (B, C, H, W); no transpose, for speed
        B, _, N = x.shape
        qkv = self.gen_qkv(x)
        qkv.relu_()  # apply the nonlinearity to q, k and v, since v is added directly below
        q, k, v = torch.split(qkv, self.dim, 1)  # k: (b, c, n), v: (b, c, n)
        # only square here; a cube might explode
        # this ends up resembling Squared ReLU
        q = f.normalize(q.square())
        k = f.normalize(k.square())
        q = q.reshape((-1, self.head_dim, N))
        k = k.reshape((-1, self.head_dim, N))
        v = v.reshape((-1, self.head_dim, N))
        z = (q.transpose(1, 2) @ k.sum(dim=-1)[..., None])[..., 0] + 1e-3
        z = z.reciprocal().mul(64)  # compensate for v/64 below; 1e3 * 64 < 65504 (float16 max)
        kv = (v / 64) @ k.transpose(1, 2)  # avoid kv blowing up
        x = kv @ q * z[:, None, :] + v  # the only other change: just add v here
        x = x.reshape((B, -1, N))
        x = self.proj(x)
        return x
Hi @nviwch, thank you again for your thoughtful attention to our work. It's worth noting that while reducing v by a factor of 64 and increasing z by a factor of 64 is equivalent to the original operation, omitting q_norm and k_norm does introduce a difference. This is because q_norm and k_norm are not scalars but tensors holding the norm of each individual q and k token. Our earlier experiments indicated that dropping these norms can result in a decline in performance.
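As a small illustration of this point (a sketch with made-up shapes and a focusing factor of 3, not the paper's settings): restoring the per-token k_norm rescales every key by a different factor, so it re-weights the attention rather than cancelling between numerator and denominator.

import torch

torch.manual_seed(0)
n, c = 49, 16                                  # tokens and head_dim, illustrative only
q = torch.rand(n, c) + 1e-6                    # non-negative, as after ReLU
k = torch.rand(n, c) + 1e-6
v = torch.randn(n, c)
p = 3.0                                        # focusing factor, assumed for the example

def linear_attn(q, k, v):
    z = 1.0 / (q @ k.sum(dim=0) + 1e-6)        # per-query denominator, shape (n,)
    return (q @ (k.T @ v)) * z[:, None]

def focus(t, p, restore_norm):
    norm = t.norm(dim=-1, keepdim=True)        # one norm per token, shape (n, 1)
    t = t ** p
    if restore_norm:
        t = t / t.norm(dim=-1, keepdim=True) * norm
    return t

out_with = linear_attn(focus(q, p, True), focus(k, p, True), v)
out_without = linear_attn(focus(q, p, False), focus(k, p, False), v)
print(torch.allclose(out_with, out_without))   # False: the two variants differ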
Hi @nviwch, we have fixed the numerical instability problem, and now the models can be trained with float16.
When I plug in FocusedLinearAttention and try to train with mixed precision, I get NaN from the start. I found that the denominator of z and the values of kv are quite large; they can easily overflow in float16.
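For illustration, here is a minimal sketch of where the overflow can come from (hypothetical sizes, not the actual training values): both the k^T v reduction over a few thousand tokens and the reciprocal of a tiny denominator can exceed float16's maximum of 65504.

import torch

n, c = 3136, 32                                        # e.g. 56*56 tokens, hypothetical sizes
k = torch.full((n, c), 5.0, dtype=torch.float16)
v = torch.full((n, c), 5.0, dtype=torch.float16)

# the "kv" reduction sums n products: 5 * 5 * 3136 = 78400 > 65504 -> inf in float16
kv = (k.unsqueeze(-1) * v.unsqueeze(1)).sum(dim=0)
print(kv[0, 0])                                        # inf

# the reciprocal denominator 1 / (x + 1e-6) can reach ~1e6, also beyond float16's range
print(torch.tensor(1e-6, dtype=torch.float16).reciprocal())   # inf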
Another question: have you tried identity/convolution instead of depthwise convolution in the attention block?