Closed ggyyzm closed 3 years ago
学长好!请问下 def th_intersection_over_union_per_class(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None): cm_th = th_confusion_matrix(y_true, y_pred, num_classes) sum_over_row = cm_th.sum(dim=0) sum_over_col = cm_th.sum(dim=0) diag = cm_th.diag() denominator = sum_over_row + sum_over_col - diag iou_per_class = diag / denominator return iou_per_class 这一段的 sum_over_col = cm_th.sum(dim=0) 是不是有问题呀? 应该是 dim=1吧 不好意思刚不小心关闭了提问,重新发一下~
def th_intersection_over_union_per_class(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None): cm_th = th_confusion_matrix(y_true, y_pred, num_classes) sum_over_row = cm_th.sum(dim=0) sum_over_col = cm_th.sum(dim=0) diag = cm_th.diag() denominator = sum_over_row + sum_over_col - diag iou_per_class = diag / denominator return iou_per_class
sum_over_col = cm_th.sum(dim=0)
dim=1
感谢指正。 对,应该是dim=1, 不过这个函数很早就被弃用,可以使用metric/pixel.py里的进行计算。 抱歉带来困扰。
非常感谢!
学长好!请问下
def th_intersection_over_union_per_class(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None): cm_th = th_confusion_matrix(y_true, y_pred, num_classes) sum_over_row = cm_th.sum(dim=0) sum_over_col = cm_th.sum(dim=0) diag = cm_th.diag() denominator = sum_over_row + sum_over_col - diag iou_per_class = diag / denominator return iou_per_class
这一段的sum_over_col = cm_th.sum(dim=0)
是不是有问题呀? 应该是dim=1
吧 不好意思刚不小心关闭了提问,重新发一下~