Z-Zheng / SimpleCV

Simplify training, evaluation, prediction in Pytorch
MIT License
63 stars 12 forks source link

the function of "th_intersection_over_union_per_class" in metric.py #170

Closed ggyyzm closed 3 years ago

ggyyzm commented 3 years ago

学长好!请问下 def th_intersection_over_union_per_class(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None): cm_th = th_confusion_matrix(y_true, y_pred, num_classes) sum_over_row = cm_th.sum(dim=0) sum_over_col = cm_th.sum(dim=0) diag = cm_th.diag() denominator = sum_over_row + sum_over_col - diag iou_per_class = diag / denominator return iou_per_class 这一段的 sum_over_col = cm_th.sum(dim=0) 是不是有问题呀? 应该是 dim=1吧 不好意思刚不小心关闭了提问,重新发一下~

Z-Zheng commented 3 years ago

感谢指正。 对,应该是dim=1, 不过这个函数很早就被弃用,可以使用metric/pixel.py里的进行计算。 抱歉带来困扰。

ggyyzm commented 3 years ago

非常感谢!