vdurnov / xview2_1st_place_solution

1st place solution for "xView2: Assess Building Damage" challenge.
MIT License

Question about validate function of classification stage #8

Status: Open. Kittywyk opened this issue 3 years ago.

Kittywyk commented 3 years ago

Hi, I have a question about the `validate(net, data_loader)` function of the classification stage in `train50_cls_cce.py`:

```python
for j in range(msks.shape[0]):  # msks.shape[0] is batch size
    tp[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] > 0).sum()
    fn[4] += np.logical_and(msks[j, 0] < 1, msk_pred[j] > 0).sum()
    fp[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] < 1).sum()

    targ = lbl_msk[j][msks[j, 0] > 0]
    pred = msk_damage_pred[j].argmax(axis=0)
    pred = pred * (msk_pred[j] > _thr)
    pred = pred[msks[j, 0] > 0]
    for c in range(4):
        tp[c] += np.logical_and(pred == c, targ == c).sum()
        fn[c] += np.logical_and(pred != c, targ == c).sum()
        fp[c] += np.logical_and(pred == c, targ != c).sum()
```

I was wondering why the loop compares `pred == c` with `targ == c` when `c` ranges from 0 to 3 but the values of `targ` range from 1 to 4. Or did I get it wrong? Please explain. Thanks a million!
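To make the concern concrete, here is a minimal sketch with made-up arrays. It assumes, as the question states, that the ground-truth damage labels in `targ` use 1 to 4, while `argmax` over four damage channels yields 0 to 3. The helper `per_class_counts` and its `offset` parameter are hypothetical illustrations, not part of the repository.

```python
import numpy as np


def per_class_counts(pred, targ, num_classes=4, offset=0):
    """Count per-class TP/FN/FP, comparing pred == c against targ == c + offset.

    offset=0 mirrors the comparison in the quoted loop;
    offset=1 aligns 0-based predictions with 1-based targets (an assumption).
    """
    tp = np.zeros(num_classes, dtype=np.int64)
    fn = np.zeros(num_classes, dtype=np.int64)
    fp = np.zeros(num_classes, dtype=np.int64)
    for c in range(num_classes):
        p = pred == c            # pixels predicted as class c
        t = targ == c + offset   # pixels labelled as class c (shifted)
        tp[c] += np.logical_and(p, t).sum()
        fn[c] += np.logical_and(~p, t).sum()
        fp[c] += np.logical_and(p, ~t).sum()
    return tp, fn, fp


# Toy building pixels: a perfect prediction, expressed 0-based vs 1-based.
targ = np.array([1, 1, 2, 3, 4, 4])  # hypothetical labels in 1..4
pred = targ - 1                      # argmax over 4 channels gives 0..3

print(per_class_counts(pred, targ, offset=0))
print(per_class_counts(pred, targ, offset=1))
```

With `offset=0` every class gets zero true positives even though the toy prediction is perfect, and pixels labelled 4 are never counted as true positives for any class; with `offset=1` all pixels land on the correct class. If `lbl_msk` is in fact already 0-based in the dataset code, the original comparison is correct as written, so it is worth checking how `lbl_msk` is built before drawing a conclusion.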