Open TengliEd opened 1 year ago
I am encountering the same issue.
@TengliEd Hello, it seems I've run into some issues during evaluation. Could you tell me what this error message means? Did the following error occur for you as well?
Traceback (most recent call last):
  File "train.py", line 279, in
Root Cause (first observed failure):
[0]:
  time       : 2023-11-12_14:41:17
  host       : I1637d69a3200801799
  rank       : 0 (local_rank: 0)
  exitcode   : 1 (pid: 21242)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
I believe it should be torch.cat(pred_list, 0) instead of torch.cat(pred_list, 1).
According to https://github.com/tianrun-chen/SAM-Adapter-PyTorch/blob/main/utils.py#L122
def calc_cod(y_pred, y_true):
    batchsize = y_true.shape[0]

    metric_FM = sod_metric.Fmeasure()
    metric_WFM = sod_metric.WeightedFmeasure()
    metric_SM = sod_metric.Smeasure()
    metric_EM = sod_metric.Emeasure()
    metric_MAE = sod_metric.MAE()
    with torch.no_grad():
        assert y_pred.shape == y_true.shape

        for i in range(batchsize):
            true, pred = \
                y_true[i, 0].cpu().data.numpy() * 255, y_pred[i, 0].cpu().data.numpy() * 255
The code iterates over batch_size (y_true.shape[0]), so with torch.cat(pred_list, 1) it would only evaluate the samples of a single batch. But the intent is to evaluate the entire validation set (len(val_set) samples).
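To make the shape argument concrete, here is a minimal sketch. It uses NumPy's np.concatenate, whose axis semantics match torch.cat, so it runs without a GPU; the batch count and tensor sizes (4 batches of 8 single-channel 32x32 predictions) are made-up numbers for illustration:

```python
import numpy as np

# Hypothetical validation run: 4 batches, batch size 8, predictions (8, 1, 32, 32)
pred_list = [np.zeros((8, 1, 32, 32)) for _ in range(4)]

cat_dim1 = np.concatenate(pred_list, axis=1)  # shape (8, 4, 32, 32)
cat_dim0 = np.concatenate(pred_list, axis=0)  # shape (32, 1, 32, 32)

# calc_cod reads batchsize = y_true.shape[0], so concatenating along dim 1
# leaves only 8 samples for the metric loop; dim 0 exposes all 32.
print(cat_dim1.shape, cat_dim0.shape)
```

With dim 1 the per-batch tensors are stacked into the channel dimension, so shape[0] stays at the batch size and the metric loop silently skips most of the validation set; with dim 0 every prediction becomes its own sample, which is what calc_cod expects.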
@tianrun-chen In func eval_psnr(),

pred_list = torch.cat(pred_list, 1)
gt_list = torch.cat(gt_list, 1)

should be

pred_list = torch.cat(pred_list, 0)
gt_list = torch.cat(gt_list, 0)