LeapLabTHU / FLatten-Transformer

Official repository of FLatten Transformer (ICCV2023)

Numerical Instability in z & kv #15

Closed nviwch closed 8 months ago

nviwch commented 10 months ago

When I plug in FocusedLinearAttention and train in mixed precision, I get NaN from the very start. I found that the denominator of z and the values of kv are quite large, so they can easily overflow in float16.

Another question: have you tried an identity mapping or a regular convolution instead of the depthwise convolution in the attention block?

tian-qing001 commented 8 months ago

Hi @nviwch, thank you for your interest in our work. Training with mixed precision occasionally produces NaN values on certain devices. To work around this, consider disabling mixed precision inside the FLatten module. For instance, you can replace the forward function in FLatten-Swin with the following code:

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
        q, k, v = qkv.unbind(0)
        k = k + self.positional_encoding
        focusing_factor = self.focusing_factor
        kernel_function = nn.ReLU()
        q = kernel_function(q) + 1e-6
        k = kernel_function(k) + 1e-6
        scale = nn.Softplus()(self.scale)
        q = q / scale
        k = k / scale
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        if float(focusing_factor) <= 6:
            q = q ** focusing_factor
            k = k ** focusing_factor
        else:
            q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factor
            k = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
        q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]

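        # Leave autocast for the reductions below: z and kv sum over all tokens
        # and can exceed the float16 range, so compute them in float32.
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k, v = q.float(), k.float(), v.float()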
            z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
            if i * j * (c + d) > c * d * (i + j):
                kv = torch.einsum("b j c, b j d -> b c d", k, v)
                x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
            else:
                qk = torch.einsum("b i c, b j c -> b i j", q, k)
                x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)

        num = int(v.shape[1] ** 0.5)
        feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)
        feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")
        x = x + feature_map

        x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
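
The branch on `i * j * (c + d) > c * d * (i + j)` in the forward pass above reads as a cost comparison between the two ways of contracting q, k and v. The sketch below counts multiplications for each order. The function names are hypothetical, and the 7x7 window and head width of 32 are illustrative assumptions rather than values taken from this thread.

```python
def contraction_costs(i, j, c, d):
    """Multiplication counts for the two evaluation orders of linear attention.

    i, j: number of query / key tokens; c, d: key / value channel widths.
    """
    # Order 1: kv = k^T v has shape (c, d), then q @ kv has shape (i, d).
    kv_first = j * c * d + i * c * d
    # Order 2: qk = q k^T has shape (i, j), then qk @ v has shape (i, d).
    qk_first = i * j * c + i * j * d
    return kv_first, qk_first


def cheaper_order(i, j, c, d):
    """Mirror the condition in the snippet: take kv first when qk first costs more."""
    kv_first, qk_first = contraction_costs(i, j, c, d)
    return "kv_first" if qk_first > kv_first else "qk_first"


# Assumed example: a 7x7 window (49 tokens) with 32 channels per head.
print(contraction_costs(49, 49, 32, 32), cheaper_order(49, 49, 32, 32))
# Assumed example: a 4x4 window (16 tokens) with the same head width.
print(contraction_costs(16, 16, 32, 32), cheaper_order(16, 16, 32, 32))
```

With many tokens relative to the channel width, forming kv first is cheaper; with few tokens, forming qk first is cheaper.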

We have not yet experimented with an identity mapping or a regular convolution as an alternative to the depthwise convolution in the attention block. This could be a valuable direction for future investigation and improvement.

nviwch commented 8 months ago

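While debugging, I found that when computing kv, replacing v with v/64, replacing z with z*64, and using epsilon=1e-3 keeps float16 from overflowing. The expression is mathematically equivalent; the numerator and denominator are simply each rescaled.

[Screenshot 2023-12-19 094216]

I also noticed that you multiply the original q_norm and k_norm back into q and k. That seems to make no difference, since it amounts to scaling the numerator and denominator together.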
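
A minimal sketch of why the v/64, z*64 rescaling helps, written in plain Python rather than PyTorch. It assumes a single query and a single channel (c = d = 1), emulates float16 storage by round-tripping through the half-precision format of `struct`, and uses made-up toy values. `to_fp16` and `linear_attention_fp16` are hypothetical helpers, not part of the repository.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to float16; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.inf if x > 0 else -math.inf


def linear_attention_fp16(q_i, k, v, scale=1.0, eps=1e-6):
    """One query, one channel, with kv, z and the output stored in float16.

    scale=1 mimics the unscaled formulation; scale=64 mimics v/64 and z*64.
    """
    kv = to_fp16(sum(k_j * (v_j / scale) for k_j, v_j in zip(k, v)))
    z = to_fp16(scale / (q_i * sum(k) + eps))
    return to_fp16(to_fp16(q_i * kv) * z)


# Toy data (assumed): 56 * 56 = 3136 tokens, all keys and values equal to 5.
k = [5.0] * 3136
v = [5.0] * 3136

unscaled = linear_attention_fp16(2.0, k, v, scale=1.0)            # kv = 78400 > 65504
rescaled = linear_attention_fp16(2.0, k, v, scale=64.0, eps=1e-3)
print(unscaled, rescaled)
```

In exact arithmetic both calls are a weighted average of v, so the result should be close to 5. Only the unscaled version pushes kv past the float16 maximum of 65504.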

Below is my line-by-line simplification. It keeps your ideas but drops a lot to speed things up.

    import torch
    import torch.nn as nn
    import torch.nn.functional as f


    class SimplifiedFocusedLinearAttention(nn.Module):
        def __init__(self, dim, num_heads):
            super().__init__()
            assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

            self.dim = dim
            self.num_heads = num_heads
            self.head_dim = dim // num_heads

            self.gen_qkv = nn.Conv1d(dim, dim*3, 1, bias=False)
            self.proj = nn.Conv1d(dim, dim, 1, bias=False)

        def forward(self, x): # Input shape is (B, C, N), flattened from an image of shape (B, C, H, W); no transpose, for speed
            B, _, N = x.shape
            qkv = self.gen_qkv(x)
            qkv.relu_() # Apply the nonlinearity to all of q, k, v, because v is added directly below

            q, k, v = torch.split(qkv, self.dim, 1) # k: (b, c, n) v: (b, c, n)

            # Only squaring is used; cubing risks overflow
            # In that case it resembles Squared ReLU
            q = f.normalize(q.square())
            k = f.normalize(k.square())

            q = q.reshape((-1, self.head_dim, N))
            k = k.reshape((-1, self.head_dim, N))
            v = v.reshape((-1, self.head_dim, N))

            z = (q.transpose(1,2) @ k.sum(dim=-1)[...,None])[...,0] + 1e-3
            z = z.reciprocal().mul(64) # Balance the product: 1e3 * 64 < 65504 (float16 max)

            kv = (v/64) @ k.transpose(1,2) # Keep kv from overflowing
            x = kv @ q * z[:,None,:] + v # Changed to simply adding v

            x = x.reshape((B, -1, N))

            x = self.proj(x)
            return x

tian-qing001 commented 8 months ago

Hi @nviwch, thank you again for your thoughtful attention to our work. Reducing v by a factor of 64 and increasing z by a factor of 64 is indeed equivalent to the original operation, but omitting q_norm and k_norm is not. q_norm and k_norm are not scalars; they are tensors holding the norm of each individual q and k. Our earlier experiments indicated that dropping these norms could reduce performance.
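
A toy check of the scalar-versus-tensor distinction, using scalar channels and plain Python. `linear_attention` is a hypothetical helper and the numbers are made up. Multiplying every key by the same scalar cancels between numerator and denominator, whereas a per-key factor (which is what a per-token norm is) reweights the keys and changes the output. The sketch covers the key side only and sets the epsilon term to zero.

```python
def linear_attention(q_i, keys, values):
    """Scalar-channel linear attention for one query, without the epsilon term."""
    num = sum(q_i * k_j * v_j for k_j, v_j in zip(keys, values))
    den = sum(q_i * k_j for k_j in keys)
    return num / den


keys = [1.0, 1.0]
values = [0.0, 10.0]

base = linear_attention(1.0, keys, values)

# The same scalar on every key: numerator and denominator grow together.
same_scale = linear_attention(1.0, [3.0 * k_j for k_j in keys], values)

# A different factor per key (assumed norms 1 and 3): the second key now dominates.
norms = [1.0, 3.0]
per_key = linear_attention(1.0, [n * k_j for n, k_j in zip(norms, keys)], values)

print(base, same_scale, per_key)
```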

tian-qing001 commented 3 months ago

Hi @nviwch, we have fixed the numerical instability problem, and now the models can be trained with float16.