LeapLabTHU / FLatten-Transformer

Official repository of FLatten Transformer (ICCV2023)

Numerical Instability #18

Closed feiwushiwo closed 4 months ago

feiwushiwo commented 10 months ago

I'd like to ask: my experimental results are really quite poor, and I don't understand why this is happening. (Attached plots: loss_curve, metrice_curve, PR_curve.)

feiwushiwo commented 10 months ago

(Attached plot: confusion_matrix.)

tian-qing001 commented 10 months ago

Hi @feiwushiwo, it appears that you are encountering numerical instability issues, possibly resulting in NaN values. We recommend turning off amp or referring to https://github.com/LeapLabTHU/FLatten-Transformer/issues/15#issuecomment-1859776961 for guidance on addressing and mitigating this problem.
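For reference, "turning off amp" in a standard PyTorch setup just means disabling the autocast/GradScaler wrappers. A minimal sketch, assuming a conventional training loop (`model`, `loader`, `criterion`, and `optimizer` are placeholder names, not code from this repository):

    import torch

    use_amp = False  # turn amp off: run the forward/backward pass entirely in fp32
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for images, targets in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):  # no-op when disabled
            loss = criterion(model(images), targets)
        scaler.scale(loss).backward()  # behaves like a plain backward() when disabled
        scaler.step(optimizer)
        scaler.update()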

feiwushiwo commented 10 months ago

amp is already turned off. The problem appeared after I later modified the code. My `FocusedLinearAttention` class is as follows:

    import numpy as np
    import torch
    import torch.nn as nn
    from einops import rearrange
    # img2windows / windows2img are the CSWin-style window partition helpers
    # defined elsewhere in the model file.

    class FocusedLinearAttention(nn.Module):
        def __init__(self, dim, resolution, split_size=7, dim_out=None, num_heads=8,
                     attn_drop=0., proj_drop=0., qk_scale=None, focusing_factor=3,
                     kernel_size=5):
            super().__init__()
            self.dim = dim
            self.dim_out = dim_out or dim
            self.resolution = resolution
            self.split_size = split_size
            self.num_heads = num_heads
            head_dim = dim // num_heads
            # NOTE: scale factor was wrong in my original version,
            # can be set manually to be compatible with prev weights
            # self.scale = qk_scale or head_dim ** -0.5
            H_sp, W_sp = self.resolution[0], self.resolution[1]
            self.H_sp = H_sp
            self.W_sp = W_sp
            stride = 1
            self.conv_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
            self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

            self.attn_drop = nn.Dropout(attn_drop)

            self.focusing_factor = focusing_factor
            self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim,
                                 kernel_size=kernel_size, groups=head_dim,
                                 padding=kernel_size // 2)
            self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
            self.positional_encoding = nn.Parameter(torch.zeros(size=(1, self.H_sp * self.W_sp, dim)))

        def im2cswin(self, x):
            B, N, C = x.shape
            H = W = int(np.sqrt(N))
            x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
            x = img2windows(x, self.H_sp, self.W_sp)
            # x = x.reshape(-1, self.H_sp * self.W_sp, C).contiguous()
            return x

        def get_lepe(self, x, func):
            B, N, C = x.shape
            H = W = int(np.sqrt(N))
            x = x.transpose(-2, -1).contiguous().view(B, C, H, W)

            H_sp, W_sp = self.H_sp, self.W_sp
            x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
            x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp)  # B', C, H', W'

            lepe = func(x)  # B', C, H', W'
            lepe = lepe.reshape(-1, C // self.num_heads, H_sp * W_sp).permute(0, 2, 1).contiguous()

            x = x.reshape(-1, C, self.H_sp * self.W_sp).permute(0, 2, 1).contiguous()
            return x, lepe

        def forward(self, qkv):
            """
            qkv: B C H W
            """
            qkv = self.conv_qkv(qkv)
            q, k, v = torch.chunk(qkv.flatten(2).transpose(1, 2), 3, dim=-1)

            # Img2Window
            H, W = self.resolution
            B, L, C = q.shape
            assert L == H * W, "flatten img_tokens has wrong size"

            q = self.im2cswin(q)
            k = self.im2cswin(k)
            v, lepe = self.get_lepe(v, self.get_v)

            k = k + self.positional_encoding
            focusing_factor = self.focusing_factor
            kernel_function = nn.ReLU()

            q = kernel_function(q) + 1e-6
            k = kernel_function(k) + 1e-6
            scale = nn.Softplus()(self.scale)
            q = q / scale
            k = k / scale
            q_norm = q.norm(dim=-1, keepdim=True)
            k_norm = k.norm(dim=-1, keepdim=True)
            # focused mapping: raise to the focusing power, then restore the norms
            if float(focusing_factor) <= 6:
                q = q ** focusing_factor
                k = k ** focusing_factor
            else:
                q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factor
                k = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factor
            q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
            k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
            q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
            i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]
            # the linear attention itself is computed in fp32 even under amp
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k, v = q.float(), k.float(), v.float()
                z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
                # pick whichever contraction order is cheaper
                if i * j * (c + d) > c * d * (i + j):
                    kv = torch.einsum("b j c, b j d -> b c d", k, v)
                    x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
                else:
                    qk = torch.einsum("b i c, b j c -> b i j", q, k)
                    x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)
            num = int(v.shape[1] ** 0.5)
            feature_map = rearrange(v, "b (h w) c -> b c h w", w=num, h=num)
            feature_map = feature_map.float()
            feature_map = rearrange(self.dwc(feature_map), "b c h w -> b (h w) c")
            x = x + feature_map
            x = x + lepe
            x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)
            x = windows2img(x, self.H_sp, self.W_sp, H, W).permute(0, 3, 1, 2)
            return x

The code is as above, and the error is:

(Attached: error screenshot.) Could you please take a look and help me find the cause of this error?

tian-qing001 commented 9 months ago

There is an inconsistency in the data types – the model weight type is Half, while the input tensor type is float. I recommend reviewing and adjusting the model weight type to ensure compatibility.
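For context, the trace is consistent with a model whose weights were converted to half precision (e.g. via `model.half()`) while the fp32-only region of the attention above (`q.float()`, `feature_map.float()`) feeds float tensors into half-precision layers such as `self.dwc`. A minimal sketch of the two usual remedies, assuming that diagnosis:

    # Option 1: keep all model weights in fp32 (drop any earlier model.half())
    # and, if speed matters, rely on autocast instead of half-precision weights.
    model = model.float()

    # Option 2: if fp16 weights must stay, cast the activation back to the
    # weight dtype right before the mismatching op (here, the dwc conv).
    feature_map = feature_map.to(self.dwc.weight.dtype)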

tian-qing001 commented 4 months ago

We have fixed the numerical instability problem, and the model can now be trained with automatic mixed precision (amp).
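With the fix in place, the standard amp recipe should apply again; the earlier training-loop sketch works with the flag flipped:

    use_amp = True  # same placeholder training loop as the sketch above
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = criterion(model(images), targets)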