Closed ggyyzm closed 3 years ago
学长好!请问下

```python
def th_intersection_over_union_per_class(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None):
    cm_th = th_confusion_matrix(y_true, y_pred, num_classes)
    sum_over_row = cm_th.sum(dim=0)
    sum_over_col = cm_th.sum(dim=0)
    diag = cm_th.diag()
    denominator = sum_over_row + sum_over_col - diag
    iou_per_class = diag / denominator
    return iou_per_class
```

这一段的 `sum_over_col = cm_th.sum(dim=0)` 是不是有问题呀?应该是 `dim=1` 吧~
def th_intersection_over_union_per_class(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None) -> torch.Tensor:
    """Compute the per-class intersection-over-union (IoU).

    For class k, IoU = TP / (TP + FP + FN). Given a confusion matrix ``cm``
    this equals ``cm[k, k] / (cm[:, k].sum() + cm[k, :].sum() - cm[k, k])``.
    Both the row total and the column total appear, so the result does not
    depend on which axis ``th_confusion_matrix`` uses for the true labels.

    Args:
        y_true: Ground-truth labels, in whatever form ``th_confusion_matrix``
            accepts (that helper is defined elsewhere; not shown here).
        y_pred: Predicted labels, same form as ``y_true``.
        num_classes: Passed through unchanged to ``th_confusion_matrix``.

    Returns:
        A 1-D tensor with one IoU value per class (one entry per diagonal
        element of the confusion matrix). A class with zero total in both
        the row and the column has a 0 / 0 denominator, and ``/`` yields NaN
        for it.
    """
    cm_th = th_confusion_matrix(y_true, y_pred, num_classes)
    # Reduce along dim 0: one total per column of the confusion matrix.
    sum_over_row = cm_th.sum(dim=0)
    # Reduce along dim 1: one total per row. This previously used dim=0 as
    # well, which counted the column totals twice and dropped the row totals
    # from the denominator.
    sum_over_col = cm_th.sum(dim=1)
    diag = cm_th.diag()
    # |A ∪ B| = |A| + |B| - |A ∩ B|: the diagonal is included in both sums.
    denominator = sum_over_row + sum_over_col - diag
    iou_per_class = diag / denominator
    return iou_per_class
sum_over_col = cm_th.sum(dim=0)
dim=1
学长好!请问下
def th_intersection_over_union_per_class(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None) -> torch.Tensor:
    """Compute the per-class intersection-over-union (IoU).

    For class k, IoU = TP / (TP + FP + FN). Given a confusion matrix ``cm``
    this equals ``cm[k, k] / (cm[:, k].sum() + cm[k, :].sum() - cm[k, k])``.
    Both the row total and the column total appear, so the result does not
    depend on which axis ``th_confusion_matrix`` uses for the true labels.

    Args:
        y_true: Ground-truth labels, in whatever form ``th_confusion_matrix``
            accepts (that helper is defined elsewhere; not shown here).
        y_pred: Predicted labels, same form as ``y_true``.
        num_classes: Passed through unchanged to ``th_confusion_matrix``.

    Returns:
        A 1-D tensor with one IoU value per class (one entry per diagonal
        element of the confusion matrix). A class with zero total in both
        the row and the column has a 0 / 0 denominator, and ``/`` yields NaN
        for it.
    """
    cm_th = th_confusion_matrix(y_true, y_pred, num_classes)
    # Reduce along dim 0: one total per column of the confusion matrix.
    sum_over_row = cm_th.sum(dim=0)
    # Reduce along dim 1: one total per row. This previously used dim=0 as
    # well, which counted the column totals twice and dropped the row totals
    # from the denominator.
    sum_over_col = cm_th.sum(dim=1)
    diag = cm_th.diag()
    # |A ∪ B| = |A| + |B| - |A ∩ B|: the diagonal is included in both sums.
    denominator = sum_over_row + sum_over_col - diag
    iou_per_class = diag / denominator
    return iou_per_class
这一段的 `sum_over_col = cm_th.sum(dim=0)` 是不是有问题呀?应该是 `dim=1` 吧~