kHarshit closed this issue 5 years ago.
The problem was that UNet takes a 572x572 input but outputs a 388x388 mask, so the predicted labels did not match the 572x572 ground truth in the metrics code. I fixed it by padding the predictions from 388x388 back to 572x572 (92 pixels on each side, since (572 - 388) / 2 = 92), as suggested in https://github.com/meetshah1995/pytorch-semseg/issues/43#issuecomment-406119804:
```python
# In /ptsemseg/metrics.py (np is numpy, already imported in that module)
def update(self, label_trues, label_preds):
    for lt, lp in zip(label_trues, label_preds):
        # print(lt.shape, lp.shape)  # (572, 572), (388, 388)
        # Reflect-pad 92 px per side so the prediction matches the ground truth: 388 + 2 * 92 = 572
        lp = np.pad(lp, ((92, 92), (92, 92)), mode='reflect')
        # print(lt.shape, lp.shape)  # (572, 572), (572, 572)
        self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
```
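A minimal standalone sketch of the shape fix, assuming 2-D integer label maps and a confusion-matrix helper written to mirror the `np.bincount` approach of ptsemseg's `_fast_hist` (the helper names below are hypothetical, not part of ptsemseg). It computes the pad width from the shapes instead of hard-coding 92, shows why mismatched shapes fail (a boolean mask over 572*572 entries cannot index a 388*388 array; the actual error from the issue is not shown, so this is only a likely cause), and, as a plainly different alternative, center-crops the ground truth to 388x388 instead of padding the prediction. The crop assumes the 388x388 output is aligned with the center of the input, as with the unpadded convolutions of the original U-Net.

```python
import numpy as np


def symmetric_pad_width(src, dst):
    """Split (dst - src) into (before, after) padding, e.g. 388 -> 572 gives (92, 92)."""
    total = dst - src
    if total < 0:
        raise ValueError("target size must be >= source size")
    return (total // 2, total - total // 2)


def pad_pred_to(label_pred, shape):
    """Reflect-pad a 2-D prediction up to `shape` (the approach used in the fix above)."""
    widths = [symmetric_pad_width(s, d) for s, d in zip(label_pred.shape, shape)]
    return np.pad(label_pred, widths, mode="reflect")


def center_crop(label_true, shape):
    """Alternative: crop the ground truth down to the prediction's size."""
    top = (label_true.shape[0] - shape[0]) // 2
    left = (label_true.shape[1] - shape[1]) // 2
    return label_true[top:top + shape[0], left:left + shape[1]]


def fast_hist(label_true, label_pred, n_class):
    """Confusion matrix via bincount; labels outside [0, n_class) are ignored."""
    mask = (label_true >= 0) & (label_true < n_class)
    return np.bincount(
        n_class * label_true[mask].astype(int) + label_pred[mask],
        minlength=n_class ** 2,
    ).reshape(n_class, n_class)


if __name__ == "__main__":
    n_classes = 3
    rng = np.random.default_rng(0)
    lt = rng.integers(0, n_classes, size=(572, 572))  # ground truth
    lp = rng.integers(0, n_classes, size=(388, 388))  # UNet output

    # Unfixed shapes: the 572*572 boolean mask cannot index the 388*388 prediction.
    try:
        fast_hist(lt.flatten(), lp.flatten(), n_classes)
    except IndexError as exc:
        print("shape mismatch:", exc)

    hist_pad = fast_hist(lt.flatten(), pad_pred_to(lp, lt.shape).flatten(), n_classes)
    hist_crop = fast_hist(center_crop(lt, lp.shape).flatten(), lp.flatten(), n_classes)
    print(hist_pad.sum(), hist_crop.sum())  # 327184 150544
```

One design note: padding counts the 92-pixel border of mirrored predictions in the metrics, while cropping evaluates only pixels the network actually predicted, so the two can give different scores.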
I'm trying to train UNet on a custom dataset. Training runs fine, but the following error occurs during the validation step (when `val_interval: 50` is reached). The config file is as follows:
The error is: