tianbaochou / NasUnet


On the computation of batch_intersection_union #24

Closed: Yanxingang closed this issue 4 years ago

Yanxingang commented 4 years ago

Hi, nice work! Regarding the IoU computation in metrics.py, isn't there a bug here?

```python
import numpy as np
import torch


def batch_intersection_union(output, target, nclass):
    predict = torch.max(output, 1)[1]
    mini = 1
    maxi = nclass-1
    nbins = nclass-1
    # label is: 0, 1, 2, ..., nclass-1
    # Note: 0 is background
    predict = predict.cpu().numpy().astype('int64') + 1
    target = target.cpu().numpy().astype('int64') + 1

    predict = predict * (target > 0).astype(predict.dtype)
    intersection = predict * (predict == target)

    # areas of intersection and union
    area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union
```

What this computes is not actually IoU. I think the main problem is the + 1.
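To make the concern concrete, the small check below reproduces only the binning step of the function above. It is a sketch, and `histogram_counts` is a hypothetical helper, not part of the repository. The + 1 shift and the upper bound `maxi = nclass-1` interact: after the shift the labels run from 1 to nclass, but the histogram range stops at nclass-1, so the highest class is never counted. For `nclass=2` the range collapses to (1, 1), which NumPy widens to (0.5, 1.5), so only the background class would be counted.

```python
import numpy as np


def histogram_counts(labels, nclass):
    """Reproduce only the binning step of batch_intersection_union:
    shift labels by +1, then histogram over range=(1, nclass-1)
    with nclass-1 bins. Hypothetical helper for illustration."""
    shifted = np.asarray(labels, dtype=np.int64) + 1
    counts, _ = np.histogram(shifted, bins=nclass - 1, range=(1, nclass - 1))
    return counts


# Toy label map: one pixel of each class 0..nclass-1.
nclass = 4
counts = histogram_counts(np.arange(nclass), nclass)
print(counts)        # → [1 1 1]  (three bins for four classes)
print(counts.sum())  # → 3        (the pixel of class 3 is never counted)

# Binary case: range (1, 1) is widened by NumPy to (0.5, 1.5),
# so only the shifted background label (1) is counted.
binary_counts = histogram_counts(np.array([0, 1]), 2)
print(binary_counts)  # → [1]
```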

tianbaochou commented 4 years ago

Hi. At the time, this calculation was based on other existing implementations.
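For comparison, one common way to get per-class intersection and union is from a confusion matrix built with `np.bincount`. This is a sketch, not the repository's code: `per_class_iou` is a hypothetical name, it expects label maps that were already reduced with `torch.max(output, 1)[1]` and converted to NumPy, and it assumes that any target value outside `[0, nclass-1]` (for example an ignore label such as -1 or 255) should be skipped.

```python
import numpy as np


def per_class_iou(pred, target, nclass):
    """Per-class intersection and union from integer label maps.

    pred, target: integer arrays of the same shape; pred values must lie in
    [0, nclass-1]. Pixels whose target is outside [0, nclass-1] are skipped.
    """
    pred = np.asarray(pred, dtype=np.int64).ravel()
    target = np.asarray(target, dtype=np.int64).ravel()
    valid = (target >= 0) & (target < nclass)
    # Confusion matrix: rows are ground-truth classes, columns are predictions.
    conf = np.bincount(nclass * target[valid] + pred[valid],
                       minlength=nclass * nclass).reshape(nclass, nclass)
    inter = np.diag(conf)                      # correctly labelled pixels per class
    union = conf.sum(axis=0) + conf.sum(axis=1) - inter
    return inter, union


# Toy example with three classes (no label shift, no dropped class).
pred = np.array([0, 1, 2, 2])
target = np.array([0, 1, 1, 2])
inter, union = per_class_iou(pred, target, 3)
print(inter)  # → [1 1 1]
print(union)  # → [1 2 2]
```

Classes absent from both prediction and target have a union of zero, so they should be excluded before averaging, e.g. `inter[union > 0] / union[union > 0]`.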