MedARC-AI / fMRI-reconstruction-NSD

fMRI-to-image reconstruction on the NSD dataset.
MIT License

Add FID scoring #3

Closed jimgoo closed 1 year ago

jimgoo commented 1 year ago

Add FID scoring following the example here: https://github.com/lucidrains/DALLE2-pytorch/blob/3b2cf7b0bc152d826f74a90f5f6b922a8b9f7b21/train_decoder.py#L210

The utils.sample_images method adds every original input image to the real FID set; by default there are n_sample_save = 8 of these. Each brain-guided SD image variation is added to the fake FID set. Since we generate 4 variations per original image, there are n_sample_save x 4 = 32 fake images. The resulting FIDs for the training and validation sets are logged to wandb.
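The bookkeeping described above can be sketched roughly as follows. This is only an illustration of how the real and fake sets grow, not the code in utils.sample_images: `collect_fid_sets`, `placeholder_generator`, and the string "images" are hypothetical stand-ins, and only the counts (8 originals, 4 variations each) come from the description.

```python
N_SAMPLE_SAVE = 8          # originals per sampling call (default from the description)
VARIATIONS_PER_IMAGE = 4   # brain-guided variations generated per original


def collect_fid_sets(originals, generate_variations):
    """Split one sampling call into a real set and a fake set for FID."""
    real_set, fake_set = [], []
    for image in originals[:N_SAMPLE_SAVE]:
        # Every original input image goes into the real set.
        real_set.append(image)
        # Every generated variation of that image goes into the fake set.
        fake_set.extend(generate_variations(image, VARIATIONS_PER_IMAGE))
    return real_set, fake_set


def placeholder_generator(image, n):
    """Hypothetical stand-in for the SD image-variation sampler."""
    return [f"{image}_var{i}" for i in range(n)]


originals = [f"img{i}" for i in range(N_SAMPLE_SAVE)]
real_set, fake_set = collect_fid_sets(originals, placeholder_generator)
print(len(real_set), len(fake_set))
```

With the defaults this gives 8 real and 32 fake entries, matching the counts quoted above.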

Since this method is only called on the master process for now, I also added a parameter that controls how often sampling runs, because sampling takes so long. Ideally we'll migrate the code from this method into the actual PyTorch diffusion_prior model so that it can be called by all processes and not slow down training.
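A minimal sketch of that kind of gating, assuming the master process is rank 0 and that the interval is counted in training steps. The names `should_sample` and `sample_every` are hypothetical; the parameter actually added in this PR is not named here.

```python
def should_sample(step, rank, sample_every):
    """Return True only on the master process, once every `sample_every` steps."""
    is_master = rank == 0  # assumption: rank 0 is the master process
    return is_master and sample_every > 0 and step % sample_every == 0


# Steps at which the master process would run the slow sampling pass.
master_steps = [s for s in range(1000) if should_sample(s, rank=0, sample_every=250)]
# Other ranks never sample under this scheme.
worker_steps = [s for s in range(1000) if should_sample(s, rank=1, sample_every=250)]
print(master_steps, worker_steps)
```

Raising `sample_every` trades fewer FID data points for less time spent blocked on sampling.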

These FID scores will vary somewhat between runs since the training images get shuffled. I will look into setting that seed inside webdataset.
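To illustrate why a fixed seed would remove this variation, here is a generic sketch using Python's standard library. It does not use webdataset's API, and how the seed is actually passed to webdataset is not covered; `shuffled_indices` is a hypothetical helper.

```python
import random


def shuffled_indices(n, seed):
    """Return a shuffled list of range(n) that depends only on `seed`."""
    rng = random.Random(seed)  # private RNG, so global random state is untouched
    indices = list(range(n))
    rng.shuffle(indices)
    return indices


first_run = shuffled_indices(100, seed=0)
second_run = shuffled_indices(100, seed=0)
# Same seed, same order: the first 8 sampled images would match across runs.
print(first_run[:8] == second_run[:8])
```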

I also added back an old notebook that I'd deleted, which shows some of the early results.

PaulScotti commented 1 year ago

Great job. Subsequent changes will be needed regarding the n_samples_save = 8 part, since FID requires more samples, and we may also want to offload this validation sampling across the GPUs.
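For context on the sample-size concern, the standard FID definition compares the feature means and covariances of the real set ($\mu_r, \Sigma_r$) and the fake set ($\mu_f, \Sigma_f$):

$$
\mathrm{FID} = \lVert \mu_r - \mu_f \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_f - 2\,(\Sigma_r \Sigma_f)^{1/2}\right)
$$

A sample covariance estimated from $N$ feature vectors has rank at most $N - 1$, so 8 real images give $\operatorname{rank}(\Sigma_r) \le 7$ and 32 fake images give $\operatorname{rank}(\Sigma_f) \le 31$. Assuming the feature dimension is larger than that (2048 is common for Inception-based FID; the dimension used here is not stated in this thread), both estimates are rank-deficient and the resulting score is noisy.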