xinge008 / Cylinder3D

Ranked 1st on the SemanticKITTI semantic segmentation leaderboard (both single-scan and multi-scan) (Nov. 2020) (CVPR 2021 Oral)
Apache License 2.0

About the weights in WCE #169

Open xpzbph opened 1 year ago

xpzbph commented 1 year ago

Hello, and thank you for your work. I looked through the code you released but could not find the per-class weights used in the weighted cross-entropy loss. Could you explain how they were set? Looking forward to your reply.
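
For context, a per-class weighted cross-entropy (WCE) over N labeled points is commonly written as below, where y_i is the ground-truth class of point i, p_{i,c} is the predicted probability of class c, and w_c is the per-class weight being asked about. Normalizing by the sum of the weights matches PyTorch's `nn.CrossEntropyLoss(weight=w)` with the default `reduction='mean'`. This is a generic definition, not necessarily the exact loss used in Cylinder3D.

```latex
\mathcal{L}_{\mathrm{WCE}}
  = -\,\frac{\sum_{i=1}^{N} w_{y_i}\,\log p_{i,y_i}}{\sum_{i=1}^{N} w_{y_i}}
```

Points whose label equals `ignore_index` are dropped from both sums.
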

L-Reichardt commented 1 year ago

@xpzbph As an alternative, you could use the class weights from 2DPass. I am using that weighting with good results. It is summarized in the following function:

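import numpy as np
import torch


def weights():
    # Points per class (index = SemanticKITTI training id, 0 = unlabeled), from 2DPass.
    seg_num_per_class = np.array([
        0,
        55437630,
        320797,
        541736,
        2578735,
        3274484,
        552662,
        184064,
        78858,
        240942562,
        17294618,
        170599734,
        6369672,
        230413074,
        101130274,
        476491114,
        9833174,
        129609852,
        4506626,
        1168181,
    ], dtype=np.float64)

    # Relative frequency of each class.
    seg_labelweights = seg_num_per_class / np.sum(seg_num_per_class)
    # Cube-root inverse-frequency weighting; the most frequent class gets ~1.0.
    seg_labelweights = np.power(
        np.amax(seg_labelweights) / (seg_labelweights + 1e-8), 1 / 3.0
    )
    seg_labelweights = torch.Tensor(seg_labelweights)

    # Intended to set the 'unlabeled' weight to 0. Note: because of the 1e-8 term
    # above, that weight is large but finite, so this Inf mask does not match it.
    seg_labelweights[seg_labelweights == float("Inf")] = 0
    return seg_labelweights
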
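
A worked version of the formula in the function above, using its counts n_c and the 1e-8 constant it adds (written ε). Index 15 holds the largest count and index 8 the smallest non-zero count.

```latex
f_c = \frac{n_c}{\sum_{k=0}^{19} n_k}, \qquad
w_c = \left( \frac{\max_k f_k}{f_c + \varepsilon} \right)^{1/3}, \qquad
\varepsilon = 10^{-8}

\sum_k n_k = 1\,451\,327\,847, \qquad
\max_k f_k = \frac{476\,491\,114}{1\,451\,327\,847} \approx 0.3283

w_0 = \left( \frac{0.3283}{10^{-8}} \right)^{1/3} \approx 320, \qquad
w_8 \approx \left( \frac{476\,491\,114}{78\,858} \right)^{1/3} \approx 18.2
```

So the 'unlabeled' weight (n_0 = 0) comes out large but finite rather than Inf. It is also larger than the weight of any real class, and it only stops mattering if class 0 is excluded from the loss.
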
xpzbph commented 1 year ago

@L-Reichardt Thank you for your answer. I have two more questions about the function you provided.

  1. Does each entry in seg_num_per_class represent the total number of points of that class in the training sequences?
  2. It seems reasonable that the weight of unlabeled should be 0, but the function you provided does not give 0 for unlabeled. Why is that?

I hope you can answer soon. Thank you very much!

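
To reproduce per-class counts like these for your own split (for example the training sequences 00-07 and 09-10), here is a minimal sketch. It assumes the SemanticKITTI .label format (one uint32 per point, semantic id in the lower 16 bits, instance id in the upper 16 bits) and a `learning_map` dict from raw ids to training ids, as in semantic-kitti.yaml. The function names are hypothetical, and raw ids missing from the map are counted as class 0.

```python
import numpy as np


def load_semantic_labels(path):
    """Read a SemanticKITTI .label file and return the raw semantic id per point."""
    raw = np.fromfile(path, dtype=np.uint32)
    # Lower 16 bits: semantic label; upper 16 bits: instance id (discarded here).
    return raw & 0xFFFF


def count_points_per_class(raw_label_arrays, learning_map, num_classes=20):
    """Sum per-class point counts over many scans.

    raw_label_arrays: iterable of 1-D integer arrays of raw semantic ids (< 2**16)
    learning_map: dict mapping raw id -> training id in [0, num_classes)
    """
    # Lookup table covering every possible 16-bit raw id; unmapped ids -> 0.
    lut = np.zeros(1 << 16, dtype=np.int64)
    for raw_id, train_id in learning_map.items():
        lut[raw_id] = train_id

    counts = np.zeros(num_classes, dtype=np.int64)
    for raw in raw_label_arrays:
        train_ids = lut[np.asarray(raw)]
        counts += np.bincount(train_ids, minlength=num_classes)
    return counts
```

Passing a generator such as `(load_semantic_labels(p) for p in paths)` keeps only one scan in memory at a time.
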
L-Reichardt commented 1 year ago

@xpzbph

  1. I presume so. This comes from 2DPass, not from me. Alternatively, the KITTI .yaml file also contains the fraction of points carrying each label relative to the total number of points in the dataset.
  2. You can set it to 0 if you want. I use 0 as the "ignore" class in my loss functions, so its weight does not matter.

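
Combining both suggestions, here is a minimal numpy-only sketch. It derives the same cube-root weights from the `content` and `learning_map` sections of a semantic-kitti.yaml-style config and sets the unlabeled weight to zero explicitly. The key names are assumed to follow the semantic-kitti-api layout, and `weights_from_content` is a hypothetical name.

```python
import numpy as np


def weights_from_content(content, learning_map, num_classes=20, eps=1e-8):
    """Cube-root inverse-frequency class weights from a semantic-kitti.yaml-style config.

    content: dict raw id -> fraction of all points carrying that raw label
    learning_map: dict raw id -> training id in [0, num_classes)
    """
    freq = np.zeros(num_classes, dtype=np.float64)
    # Several raw ids (e.g. moving and static variants) share one training id.
    for raw_id, frac in content.items():
        freq[learning_map.get(raw_id, 0)] += frac

    # Treat 'unlabeled' as having zero points, like the 2DPass counts, so it
    # does not influence the maximum below; then normalize.
    freq[0] = 0.0
    freq /= freq.sum()

    # Same formula as weights(): (max_k f_k / (f_c + eps)) ** (1/3).
    weights = np.power(freq.max() / (freq + eps), 1.0 / 3.0)

    # eps keeps the unlabeled weight finite, so zero it explicitly.
    weights[0] = 0.0
    return weights
```

Assuming PyYAML is installed, `cfg = yaml.safe_load(open(path))` followed by `w = weights_from_content(cfg["content"], cfg["learning_map"])` would give the weights. They could then be passed as `torch.nn.CrossEntropyLoss(weight=torch.from_numpy(w).float(), ignore_index=0)`. With `ignore_index=0`, the class-0 weight has no effect either way, which is the point made in item 2.
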
xpzbph commented 1 year ago

@L-Reichardt Thank you for your reply. I will keep experimenting.

xpzbph commented 1 year ago

@L-Reichardt Hello, sorry to bother you again. I trained the network on an RTX 2080 Ti using the weight function you provided, but after 40 epochs the accuracy dropped and the validation loss increased. What did your results look like after training? I look forward to your reply.