Closed: zaccharieramzi closed this 3 years ago
Just to give more motivation, this allows me to do something like the following:
#### Final evaluation

```python
# Imports assume the fastMRI package layout (fastmri.save_reconstructions,
# fastmri.data.transforms.center_crop, fastmri.evaluate.evaluate).
from collections import defaultdict

import numpy as np

import fastmri
from fastmri.data import transforms as T
from fastmri.evaluate import evaluate

# Reuse the validation dataloader for prediction.
data_module.predict_dataloader = data_module.val_dataloader
val_reconstructions = trainer.predict(datamodule=data_module)

outputs = defaultdict(list)
# Note the call: zipping with the bound method itself would fail,
# since a method object is not iterable.
for val_recon, in_batch in zip(val_reconstructions, data_module.val_dataloader()):
    _, _, _, fname, slice_num, _, crop_size = in_batch
    crop_size = crop_size[0]  # always have a batch size of 1 for varnet
    # detect FLAIR 203
    if val_recon.shape[-1] < crop_size[1]:
        crop_size = (val_recon.shape[-1], val_recon.shape[-1])
    val_recon = T.center_crop(val_recon, crop_size)
    outputs[fname].append((slice_num, val_recon))

# Save outputs, stacking each volume's slices in slice order.
for fname in outputs:
    outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])
prediction_path = args.default_root_dir / f"{args.name}_reconstructions"
fastmri.save_reconstructions(outputs, prediction_path)

# Compute volume-wise metrics against the ground-truth validation targets.
recons_key = (
    "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc"
)
args.prediction_path = prediction_path
args.target_path = data_module.data_path / f"{args.challenge}_val"
metrics = evaluate(args, recons_key)
print(metrics)
```
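The crop-and-group pattern in the snippet can be sketched in plain NumPy. This is only an illustration: `center_crop` is re-implemented here rather than taken from `fastmri.data.transforms`, and the per-slice inputs are fake arrays standing in for model outputs.

```python
from collections import defaultdict

import numpy as np


def center_crop(data, shape):
    # Crop the last two axes of `data` to `shape`, centered
    # (a simplified stand-in for fastmri.data.transforms.center_crop).
    h_from = (data.shape[-2] - shape[0]) // 2
    w_from = (data.shape[-1] - shape[1]) // 2
    return data[..., h_from:h_from + shape[0], w_from:w_from + shape[1]]


# Fake per-slice reconstructions: (fname, slice_num, array), out of order.
slices = [
    ("file_a", 1, np.ones((320, 320))),
    ("file_a", 0, np.zeros((320, 320))),
    ("file_b", 0, np.full((320, 320), 2.0)),
]

# Group cropped slices by volume filename.
outputs = defaultdict(list)
for fname, slice_num, recon in slices:
    recon = center_crop(recon, (160, 160))
    outputs[fname].append((slice_num, recon))

# Stack each volume's slices in slice order; sorting the (slice_num, array)
# tuples orders by slice_num first.
volumes = {
    fname: np.stack([out for _, out in sorted(pairs)])
    for fname, pairs in outputs.items()
}
print(volumes["file_a"].shape)  # (2, 160, 160)
```

Sorting before stacking matters because `predict` yields slices in dataloader order, while the evaluation expects each volume assembled slice by slice.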
Ah interesting, I hadn't noticed the `set_epoch` thing (tbh I didn't test this before because I figured it might be a problem). This would allow the use of the `predict` method from the PyTorch Lightning trainer in order to evaluate the data volume-wise. I think it would also not cause any issue, because the validation data does not need to be shuffled, unlike the training data.
I didn't open an issue because this is such a small PR, but I'd be glad to if needed.