facebookresearch / fastMRI

A large-scale dataset of both raw MRI measurements and clinical MRI images.
https://fastmri.org
MIT License

No Shuffle for samplers and loaders in data module for test and validation #153

Closed zaccharieramzi closed 3 years ago

zaccharieramzi commented 3 years ago

This would make it possible to use the predict method of the PyTorch Lightning (ptl) trainer to evaluate the data volume by volume.

I don't think it would cause any issues either, because, unlike the training data, the validation data does not need to be shuffled.
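
To illustrate why ordering matters here, below is a minimal pure-Python sketch. It is not fastMRI or PyTorch Lightning code: `toy_loader`, `predict`, and `aligned_fraction` are hypothetical stand-ins. It shows that two passes over an unshuffled loader line up item by item, whereas two passes over a shuffled loader generally do not, so predictions from one pass can only be paired with batches from a second pass when shuffling is off.

```python
import random


def toy_loader(samples, shuffle, rng):
    """Hypothetical stand-in for a data loader.

    Yields the samples in order, or in a fresh random order on each pass
    when shuffle is True.
    """
    order = list(range(len(samples)))
    if shuffle:
        rng.shuffle(order)
    for i in order:
        yield samples[i]


def predict(sample):
    """Hypothetical stand-in for a model: tags the sample it was given."""
    fname, slice_num = sample
    return f"recon:{fname}:{slice_num}"


def aligned_fraction(samples, shuffle, seed=0):
    """Fraction of predictions (pass 1) that match their zipped batch (pass 2)."""
    rng = random.Random(seed)
    preds = [predict(s) for s in toy_loader(samples, shuffle, rng)]  # pass 1
    pairs = zip(preds, toy_loader(samples, shuffle, rng))  # pass 2
    hits = sum(pred == predict(sample) for pred, sample in pairs)
    return hits / len(samples)


# 3 toy volumes with 10 slices each
samples = [(f"file{v}.h5", s) for v in range(3) for s in range(10)]
print(aligned_fraction(samples, shuffle=False))  # 1.0
print(aligned_fraction(samples, shuffle=True))
```

With shuffling off, every prediction is paired with the batch that produced it. With shuffling on, the pairing depends on two independent random orders and is almost never fully aligned.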

I didn't open an issue because this is such a small PR, but I'm glad to do so if needed.

zaccharieramzi commented 3 years ago

Just to give more motivation, this allows me to do something like the following:

        # Final evaluation
        data_module.predict_dataloader = data_module.val_dataloader
        val_reconstructions = trainer.predict(datamodule=data_module)
        outputs = defaultdict(list)
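        # val_dataloader is a method, so call it to get an iterable loader
        for val_recon, in_batch in zip(val_reconstructions, data_module.val_dataloader()):
            _, _, _, fname, slice_num, _, crop_size = in_batch
            # default collation wraps each field in a batch of size 1, so unwrap it
            fname, slice_num = fname[0], int(slice_num[0])
            crop_size = crop_size[0]  # always have a batch size of 1 for varnet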
            # detect FLAIR 203
            if val_recon.shape[-1] < crop_size[1]:
                crop_size = (val_recon.shape[-1], val_recon.shape[-1])
            val_recon = T.center_crop(val_recon, crop_size)
            outputs[fname].append((slice_num, val_recon))
        # save outputs
        for fname in outputs:
            outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

        prediction_path = args.default_root_dir / f"{args.name}_reconstructions"
        fastmri.save_reconstructions(outputs, prediction_path)
        recons_key = (
            "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc"
        )
        args.prediction_path = prediction_path
        args.target_path = data_module.data_path / f"{args.challenge}_val"
        metrics = evaluate(args, recons_key)
        print(metrics)

Right after https://github.com/facebookresearch/fastMRI/blob/master/fastmri_examples/varnet/varnet_knee_leaderboard_20201111.py#L73
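
The grouping step in the snippet above can be sketched in isolation. This is a simplified illustration under stated assumptions, not the fastMRI implementation: `group_by_volume` is a hypothetical helper, and plain strings stand in for the reconstructed slices. It collects slices per file name and orders them by slice number.

```python
from collections import defaultdict


def group_by_volume(predictions):
    """Group (fname, slice_num, recon) triples into per-volume slice lists.

    Returns a dict mapping each file name to its reconstructions,
    ordered by slice number.
    """
    volumes = defaultdict(list)
    for fname, slice_num, recon in predictions:
        volumes[fname].append((slice_num, recon))
    return {
        # sort on the slice number only, so the reconstructions are never compared
        fname: [recon for _, recon in sorted(items, key=lambda item: item[0])]
        for fname, items in volumes.items()
    }


# Toy predictions arriving out of slice order
preds = [("a.h5", 1, "a1"), ("b.h5", 0, "b0"), ("a.h5", 0, "a0"), ("b.h5", 1, "b1")]
volumes = group_by_volume(preds)
print(volumes)  # {'a.h5': ['a0', 'a1'], 'b.h5': ['b0', 'b1']}
```

Sorting with an explicit key is a small design choice: it keeps the sort from falling back to comparing the reconstructions themselves if two entries ever share a slice number.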

zaccharieramzi commented 3 years ago

Ah interesting, I hadn't noticed the set_epoch thing (tbh I didn't test this before because I figured it might be a problem).