dvlab-research / PanopticFCN

Fully Convolutional Networks for Panoptic Segmentation (CVPR2021 Oral)
Apache License 2.0

Some questions about understanding the code #13

Closed: Jing-W98 closed this issue 3 years ago

Jing-W98 commented 3 years ago

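Hello, I didn't understand the following code. What do `pred_cate` and `pred_score` mean? And what does `cate_matrix = pred_cate.unsqueeze(-1) == pred_cate.unsqueeze(0)` mean? Thank you.

```python
import torch

def kernel_fusion(self, meta_weight, pred_cate, pred_score):
    # Method excerpt: cal_similarity, sim_type, sim_thres, class_spec are attributes of self.
    meta_weight = meta_weight.squeeze(0)
    similarity = self.cal_similarity(meta_weight, meta_weight, sim_type=self.sim_type)
    # Keep only the upper triangle (pairs i <= j) whose similarity reaches the threshold.
    label_matrix = similarity.triu(diagonal=0) >= self.sim_thres
    if self.class_spec:
        # (N, 1) == (1, N) broadcasts to (N, N): True where kernels i and j share a pred_cate value.
        cate_matrix = pred_cate.unsqueeze(-1) == pred_cate.unsqueeze(0)
        label_matrix = label_matrix & cate_matrix
    # cum_matrix[i, j]: at most one True in column j among rows 0..i.
    cum_matrix = torch.cumsum(label_matrix.float(), dim=0) < 2
    # keep_matrix[j]: with a True diagonal, no earlier kernel i < j matched kernel j.
    keep_matrix = cum_matrix.diagonal(0)
    label_matrix = (label_matrix[keep_matrix] & cum_matrix[keep_matrix]).float()
    label_norm = label_matrix.sum(dim=1, keepdim=True)
    # Each surviving kernel becomes the mean of the kernels assigned to it.
    meta_weight = torch.mm(label_matrix, meta_weight) / label_norm
    pred_cate = pred_cate[keep_matrix]
    pred_score = pred_score[keep_matrix]
    return meta_weight, pred_cate, pred_score
```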
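To make the tensor logic concrete, here is a minimal pure-Python sketch that mirrors `kernel_fusion` on plain lists instead of tensors. It assumes that `pred_cate` holds each candidate kernel's predicted category id and `pred_score` its confidence score, that `cal_similarity` computes cosine similarity, and that the inputs are ordered by descending score. The helper names `cosine_similarity`, `same_category_matrix` and `kernel_fusion_sketch`, and the default threshold `0.9`, are hypothetical and not taken from the repository.

```python
import math
from typing import List, Tuple

Vector = List[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    # Assumed stand-in for self.cal_similarity (cosine similarity).
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def same_category_matrix(pred_cate: List[int]) -> List[List[bool]]:
    # Mirrors pred_cate.unsqueeze(-1) == pred_cate.unsqueeze(0): an (N, 1)
    # column compared with a (1, N) row broadcasts to an (N, N) matrix whose
    # entry [i][j] is True iff kernels i and j have the same category.
    return [[ci == cj for cj in pred_cate] for ci in pred_cate]


def kernel_fusion_sketch(
    meta_weight: List[Vector],
    pred_cate: List[int],
    pred_score: List[float],
    sim_thres: float = 0.9,  # hypothetical threshold value
    class_spec: bool = True,
) -> Tuple[List[Vector], List[int], List[float]]:
    n = len(meta_weight)
    # Upper triangle (i <= j) of similarity >= sim_thres; this matches
    # similarity.triu(diagonal=0) >= sim_thres whenever sim_thres > 0.
    label = [
        [j >= i and cosine_similarity(meta_weight[i], meta_weight[j]) >= sim_thres
         for j in range(n)]
        for i in range(n)
    ]
    if class_spec:
        cate = same_category_matrix(pred_cate)
        label = [[label[i][j] and cate[i][j] for j in range(n)] for i in range(n)]
    # cum[i][j]: fewer than 2 True entries in column j among rows 0..i
    # (the torch.cumsum(..., dim=0) < 2 step).
    cum = [[False] * n for _ in range(n)]
    for j in range(n):
        count = 0
        for i in range(n):
            count += label[i][j]
            cum[i][j] = count < 2
    # Diagonal of cum: kernel j survives unless an earlier kernel matched it
    # (assuming its own diagonal self-match is True).
    keep = [cum[j][j] for j in range(n)]
    fused_w: List[Vector] = []
    fused_c: List[int] = []
    fused_s: List[float] = []
    for i in range(n):
        if not keep[i]:
            continue
        # Kernels assigned to surviving kernel i; the fused kernel is their mean.
        members = [j for j in range(n) if label[i][j] and cum[i][j]]
        dim = len(meta_weight[i])
        fused_w.append(
            [sum(meta_weight[j][d] for j in members) / len(members) for d in range(dim)]
        )
        fused_c.append(pred_cate[i])
        fused_s.append(pred_score[i])
    return fused_w, fused_c, fused_s


# Toy inputs (made up for illustration), already sorted by descending score.
meta_weight = [[1.0, 0.0], [0.99, 0.1], [0.0, 1.0], [1.0, 0.05]]
pred_cate = [3, 3, 7, 5]
pred_score = [0.9, 0.8, 0.7, 0.6]

print(same_category_matrix([3, 3, 7]))
# [[True, True, False], [True, True, False], [False, False, True]]
weights, cates, scores = kernel_fusion_sketch(meta_weight, pred_cate, pred_score)
print(cates, scores)  # [3, 7, 5] [0.9, 0.7, 0.6]
_, cates_any, _ = kernel_fusion_sketch(meta_weight, pred_cate, pred_score, class_spec=False)
print(cates_any)  # [3, 7]
```

In this toy input, kernels 0 and 1 are similar and share category 3, so they are averaged into one kernel. Kernel 3 is also similar to kernel 0 but has category 5, so it stays separate when `class_spec=True` and is merged into the first group when `class_spec=False`. Because the first kernel of each group is the one kept, its category and score are what survive, which is why the input ordering matters.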