FVL2020 / ICCV-2023-MB-TaylorFormer


my data for train got loss:nan #4

Closed: akbarxie closed this issue 10 months ago

akbarxie commented 10 months ago
Hello, I am trying to use the network for image segmentation with an L1 loss, and I have also tried a Dice loss, but the loss becomes **nan**. On a small dataset it works well.
However, when I use large images for segmentation, training fails because the loss is nan.

The image looks like this: [image 11156], and its mask: [image 11156]

akbarxie commented 10 months ago

We computed the max and min values of the output and target tensors:

'''
predict: -inf & min -inf
target: 1.0 & min 0.0
loss ce: tensor(nan, device='cuda:0', grad_fn=) is nan: ['/home/lthpc/Algo/xjw/ICCV-2023-MB-TaylorFormer/Datasets_seg_all_data/train/input/12439.png']
'''
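To find where non-finite values first appear, it can help to check each intermediate tensor rather than only the final output. The helper below is a hypothetical sketch (`report_nonfinite` is not part of the repository), built only on `torch.isnan` and `torch.isinf`:

```python
import torch


def report_nonfinite(name: str, t: torch.Tensor) -> bool:
    """Print a summary and return True if `t` holds any NaN or +/-inf values.

    Hypothetical debugging helper; not part of MB-TaylorFormer.
    """
    t = t.detach()
    n_nan = int(torch.isnan(t).sum())
    n_inf = int(torch.isinf(t).sum())
    if n_nan or n_inf:
        print(f"{name}: {n_nan} NaN, {n_inf} inf, shape={tuple(t.shape)}")
        return True
    return False


# Example with made-up tensors resembling the log above.
predict = torch.full((1, 1, 4, 4), float("-inf"))
target = torch.randint(0, 2, (1, 1, 4, 4)).float()
report_nonfinite("predict", predict)  # prints: predict: 0 NaN, 16 inf, shape=(1, 1, 4, 4)
report_nonfinite("target", target)    # prints nothing
```

Calling it after each block of the forward pass narrows down the first bad tensor. For the backward pass, PyTorch's `torch.autograd.set_detect_anomaly(True)` reports the operation that produced a NaN gradient, at a noticeable speed cost.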

raindrop313 commented 10 months ago

I suspect the nan is caused by the large image size. The relevant code in T-MSA is:

"""
out_numerator = torch.sum(v, dim=-2).unsqueeze(2) + (q @ attn)
out_denominator = torch.full((h*w, c // self.num_heads), h*w).to(q.device) \
    + q @ torch.sum(k, dim=-1).unsqueeze(3).repeat(1, 1, 1, c // self.num_heads) + 1e-6
"""
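Reading the snippet, the numerator and denominator look like a first-order Taylor expansion of softmax attention, exp(q_i·k_j) ≈ 1 + q_i·k_j, rearranged so it runs in linear time over the N = h*w positions. The sketch below is my reading of that part only, not the repository's module (refine_weight and the rest of T-MSA are left out). It assumes q has shape [B, H, N, d], k has shape [B, H, d, N] (consistent with `attn = k@v` and the `[1, 1, 65536, 24]` shape of v printed later in this thread), and that the output is `out_numerator / out_denominator`. It checks the linear form against an explicit N×N version:

```python
import torch


def taylor_attention(q, k, v, eps=1e-6):
    """Linear-time form mirroring out_numerator / out_denominator above (a sketch).

    Assumed shapes: q [B, H, N, d], k [B, H, d, N], v [B, H, N, d], N = h*w.
    Each softmax weight exp(q_i . k_j) is replaced by (1 + q_i . k_j).
    """
    n, d = q.shape[-2], q.shape[-1]
    attn = k @ v                                           # [B, H, d, d]
    numerator = v.sum(dim=-2).unsqueeze(2) + q @ attn      # sum_j (1 + q_i.k_j) v_j
    k_sum = k.sum(dim=-1).unsqueeze(3).repeat(1, 1, 1, d)  # sum_j k_j, copied to d columns
    denominator = torch.full((n, d), float(n), device=q.device) + q @ k_sum + eps
    return numerator / denominator


def taylor_attention_naive(q, k, v, eps=1e-6):
    """Same quantity with an explicit N x N weight matrix, for checking only."""
    weights = 1 + q @ k                                    # [B, H, N, N]
    return (weights @ v) / (weights.sum(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
q = 0.1 * torch.randn(1, 2, 16, 4)
k = 0.1 * torch.randn(1, 2, 4, 16)
v = torch.randn(1, 2, 16, 4)
print(torch.allclose(taylor_attention(q, k, v), taylor_attention_naive(q, k, v), atol=1e-5))
```

Both torch.sum terms run over all N = h*w positions, so their magnitude grows with the image size, which is the growth described in the next paragraph.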

When the image is very large, the terms "torch.sum(k,dim=-1)" and "torch.sum(v, dim=-2)" become very large, since both sum over all h*w pixel positions. If a value exceeds the range of FP32, the gradient explodes, which in turn crashes the training. One solution is to scale q and k, for example:

"""
q_norm = torch.norm(q, p=2, dim=-1, keepdim=True) / self.norm
q = torch.div(q, q_norm)
k_norm = torch.norm(k, p=2, dim=-2, keepdim=True) / self.norm
k = torch.div(k, k_norm)

refine_weight = self.refine_att(q, k, v, size=(h, w))
refine_weight = self.sigmoid(refine_weight)

q = q / 100
k = k / 100
attn = k @ v

out_numerator = torch.sum(v, dim=-2).unsqueeze(2) / 10000 + (q @ attn)
out_denominator = torch.full((h*w, c // self.num_heads), h*w).to(q.device) / 10000 \
    + q @ torch.sum(k, dim=-1).unsqueeze(3).repeat(1, 1, 1, c // self.num_heads) + 1e-6
"""
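Dividing q and k by 100 scales both q@attn and the q@torch.sum(k, ...) term by 1/10000, and the two constant terms are divided by 10000 as well, so numerator and denominator shrink by the same factor and their ratio is unchanged apart from the 1e-6 epsilon. The sketch below checks this numerically under the same assumptions as the earlier reading of T-MSA (q [B, H, N, d], k [B, H, d, N], output = numerator / denominator). The generic scale `s` is my own parameter standing in for the 1/100 above; it is not from the repository:

```python
import torch


def scaled_taylor_ratio(q, k, v, s=1.0, eps=1e-6):
    """Numerator / denominator with q and k pre-scaled by s (a sketch).

    Assumed shapes: q [B, H, N, d], k [B, H, d, N], v [B, H, N, d].
    Both constant terms are multiplied by s**2 (1/10000 when s = 1/100),
    so numerator and denominator shrink by the same factor.
    """
    n, d = q.shape[-2], q.shape[-1]
    q, k = q * s, k * s
    attn = k @ v
    numerator = v.sum(dim=-2).unsqueeze(2) * s**2 + q @ attn
    k_sum = k.sum(dim=-1).unsqueeze(3).repeat(1, 1, 1, d)
    denominator = torch.full((n, d), n * s**2, device=q.device) + q @ k_sum + eps
    return numerator / denominator, denominator


torch.manual_seed(0)
q = 0.1 * torch.randn(1, 1, 256, 8)
k = 0.1 * torch.randn(1, 1, 8, 256)
v = torch.randn(1, 1, 256, 8)
out_plain, den_plain = scaled_taylor_ratio(q, k, v, s=1.0, eps=0.0)
out_small, den_small = scaled_taylor_ratio(q, k, v, s=0.01, eps=0.0)
print(torch.allclose(out_plain, out_small, rtol=1e-4, atol=1e-6))
print(den_plain.abs().max().item(), den_small.abs().max().item())
```

Because only the common factor s**2 changes, s can in principle be tied to the image size. One hypothetical, untested choice is s = 1/sqrt(h*w), which makes the constant term in the denominator exactly 1. Keep s**2 * h*w well above eps; otherwise the 1e-6 term starts to bias the ratio.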

Good luck in solving this problem!

raindrop313 commented 10 months ago

@akbarxie You can also choose an appropriate scaling factor according to the size of your image.

akbarxie commented 10 months ago

@raindrop313 After printing debug logs, I found the following:

'''
    print("self.norm:", self.norm)

    print("k step 0", k[0, 0, 0, 0:10], "max:", torch.max(k), " min:", torch.min(k))

    k_norm=torch.norm(k,p=2,dim=-2,keepdim=True)/self.norm

    print("k_norm step 1", k_norm[0, 0, 0, 0:10], "max:", torch.max(k_norm), " min:", torch.min(k_norm))

    k=torch.div(k,k_norm)

    print("k step 2", k[0, 0, 0, 0:10], "max:", torch.max(k), " min:", torch.min(k))

    refine_weight = self.refine_att(q,k, v, size=(h, w))
    #refine_weight=self.Leakyrelu(refine_weight)
    refine_weight = self.sigmoid(refine_weight)
    #attn = k@v
    ##attn = attn.softmax(dim=-1)

    #print(torch.sum(k, dim=-1).unsqueeze(3).shape)
    #out_numerator = torch.sum(v, dim=-2).unsqueeze(2)+(q@attn)
    #out_denominator = torch.full((h*w,c//self.num_heads),h*w).to(q.device)\
    #                  +q@torch.sum(k, dim=-1).unsqueeze(3).repeat(1,1,1,c//self.num_heads)+1e-6

    # add new add norm

    q = q / 100
    k = k / 100
    attn = k@v

    print("k", k[0, 0, 0, 0:10], "max:", torch.max(k), " min:",  torch.min(k))
    print("v", v[0, 0, 0, 0:10], "max:", torch.max(v), " min:",  torch.min(v))

    #print("q", q[0, 0, 0, 0:10])
    print("attn", attn[0, 0, 0, 0:10])

    #print("v shape:", v.shape) #[1, 1, 65536, 24]
    # print("v", v[0, 0, 0, :])
    # a_tmp = torch.sum(v, dim=-2)
    # print("a_tmp", a_tmp[0, 0, :])

    #a_tmp_div_1w = torch.sum(v/10000, dim=-2)
    #print("a_tmp_div_1w", a_tmp_div_1w[0, 0, 0:10])

    #a_mul_attn =q @ attn
    #print("a_mul_attn", a_mul_attn[0, 0, 0, 0:10])

    out_numerator = torch.sum(v/10000, dim=-2).unsqueeze(2) + (q@attn)
    out_denominator = torch.full((h*w, c // self.num_heads), h*w).to(q.device) / 10000 \
                      +q@torch.sum(k, dim=-1).unsqueeze(3).repeat(1, 1, 1, c // self.num_heads) + 1e-6

'''

There are zeros in k_norm (note "min: 0." at k_norm step 1), so should I add a small value such as 1e-4 to it? @raindrop313

'''
self.norm: 0.5
k step 0 [ 0.1113, -0.1569, -0.1257, -0.0193, -0.0194, -0.0203, -0.0177, -0.0179, -0.0203, -0.0195], max: 2.3573, min: -3.4106
k_norm step 1 [5.2678, 7.9936, 8.1658, 8.1649, 8.1578, 8.1560, 8.1526, 8.1599, 8.1615, 8.1608], max: 11.3950 ### min: 0.
k step 2 tensor([ 0.0211, -0.0196, -0.0154, -0.0024, -0.0024, -0.0025, -0.0022, -0.0022, -0.0025, -0.0024] max: nan min: nan
k [ 2.1123e-04, -1.9629e-04, -1.5397e-04, -2.3631e-05, -2.3838e-05, -2.4833e-05, -2.1759e-05, -2.1968e-05, -2.4881e-05, -2.3847e-05], max: nan, min: nan
v tensor([-0.0448, -0.0507, -0.1200, -0.2497, 0.0207, 0.7732, 0.0228, -0.0495, -0.0430, 0.0667] max: 1.4037, min: -2.2425
attn [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]
'''
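The "min: 0." in k_norm is enough to explain the log. If every channel of k is zero at some position, its L2 norm over dim=-2 is 0, so torch.div(k, k_norm) computes 0/0 = NaN there. The printed slice of k still looks finite because that position is elsewhere, yet max/min report nan, and attn = k@v sums over every position, so every entry of attn becomes NaN. A minimal reproduction with a made-up toy tensor, assuming the layout k [B, H, d, N]:

```python
import torch

torch.manual_seed(0)
# Toy k with layout [B, H, d, N]; the last position is zero in every channel (made-up data).
k = torch.tensor([[[[0.5, -1.0, 0.0],
                    [2.0,  0.3, 0.0]]]])          # shape [1, 1, 2, 3]
v = torch.randn(1, 1, 3, 2)                       # shape [1, 1, N=3, d=2]
norm_scale = 0.5                                  # plays the role of self.norm in the log

k_norm = torch.norm(k, p=2, dim=-2, keepdim=True) / norm_scale  # [1, 1, 1, 3]
k_bad = torch.div(k, k_norm)   # last position: 0 / 0 -> NaN
attn = k_bad @ v               # every entry sums over the NaN position
print(k_norm.min().item(), torch.isnan(k_bad).any().item(), torch.isnan(attn).all().item())
# prints: 0.0 True True
```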

akbarxie commented 10 months ago

I added 1 to k_norm and the problem disappeared. I hope it does not come back.

raindrop313 commented 10 months ago

@akbarxie You are right. The uploaded code is an older version and needs to be changed to the following:

"""
# + 1e-6 keeps each norm strictly positive, so an all-zero vector is not divided by 0
q_norm = torch.norm(q, p=2, dim=-1, keepdim=True) / self.norm + 1e-6
q = torch.div(q, q_norm)
k_norm = torch.norm(k, p=2, dim=-2, keepdim=True) / self.norm + 1e-6
k = torch.div(k, k_norm)
"""
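Adding 1e-6 keeps the divisor strictly positive, so an all-zero position now maps to 0 instead of NaN. Below, the same fix is wrapped in a standalone function for testing; the name `scaled_l2_normalize` is hypothetical and `norm_scale` stands in for `self.norm`. For comparison, `torch.nn.functional.normalize` handles the same case by clamping the norm at `eps` rather than adding it, so `norm_scale * F.normalize(x, p=2.0, dim=dim)` gives nearly the same result whenever the norm is not tiny:

```python
import torch
import torch.nn.functional as F


def scaled_l2_normalize(x: torch.Tensor, dim: int, norm_scale: float, eps: float = 1e-6) -> torch.Tensor:
    """Divide x by (its L2 norm along `dim` / norm_scale + eps), as in the fixed snippet.

    An all-zero slice has norm 0, so it is divided by eps and stays 0 instead of becoming NaN.
    """
    x_norm = torch.norm(x, p=2, dim=dim, keepdim=True) / norm_scale + eps
    return torch.div(x, x_norm)


k = torch.tensor([[[[0.5, -1.0, 0.0],
                    [2.0,  0.3, 0.0]]]])     # [B, H, d, N]; last position all zeros (toy data)
k_fixed = scaled_l2_normalize(k, dim=-2, norm_scale=0.5)
k_ref = 0.5 * F.normalize(k, p=2.0, dim=-2)  # clamps the norm at eps instead of adding eps
print(torch.isfinite(k_fixed).all().item(), torch.allclose(k_fixed, k_ref, atol=1e-5))
# prints: True True
```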

akbarxie commented 10 months ago

thank you!! @raindrop313