if inffer:
l = 0
model.eval()
with torch.inference_mode():
for step,data in enumerate(val_data):
images, labels = data
labels[labels > 0] = 1
labels = torch.Tensor(labels).long().to(device)
pred = model(images.to(device))
pred,labels=pred.cpu(),labels.cpu()
for i in range(images.size(0)):
pilImage = to_pil_image(images[i])
pilInfer = to_pil_image(pred[i])
pilImage.save(f'{infferPredictPath}/1/image{l}.png')
pilInfer.save(f'{infferPredictPath}/1/infer{l}.png')
l += 1
return
我在 train.py 里加了上面这段代码，试图把原始图片和推理结果保存下来，但保存出来的结果却是这样的：