MedARC-AI / fMRI-reconstruction-NSD

fMRI-to-image reconstruction on the NSD dataset.
MIT License

Train/val split counts and DDP tweaks for cluster #2

Closed jimgoo closed 1 year ago

jimgoo commented 1 year ago

For the train/val split sample counts for the NSD webdataset, we've been using num_train = 24983 and num_val = 492 in the code. However, these apply only to dataset commits up to 9947586, before @PaulScotti's update in January. After that commit, the dataset includes a metadata file that contains these counts. The stability cluster has the latest version of this dataset. I've updated utils.get_dataloaders to accept a URL for this metadata file; if no URL is given, it falls back to the old counts. The method also returns the counts so they can be used downstream for things like OneCycleLR.
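Roughly, the fallback logic looks like the sketch below. This is illustrative, not the actual utils.get_dataloaders code: the function name get_split_counts and the metadata keys num_train/num_val are assumptions about the metadata file's format.

```python
import json
import math
import urllib.request

# Legacy hard-coded split counts, valid for dataset commits up to 9947586.
DEFAULT_NUM_TRAIN = 24983
DEFAULT_NUM_VAL = 492

def get_split_counts(meta_url=None):
    """Return (num_train, num_val) for the NSD webdataset.

    If a metadata URL is given, read the counts from it (the JSON key
    names here are assumed, not confirmed); otherwise fall back to the
    old hard-coded counts.
    """
    if meta_url is None:
        return DEFAULT_NUM_TRAIN, DEFAULT_NUM_VAL
    with urllib.request.urlopen(meta_url) as f:
        meta = json.load(f)
    return meta["num_train"], meta["num_val"]

def one_cycle_total_steps(num_train, batch_size, num_epochs):
    """Why the counts are returned downstream: OneCycleLR needs the
    total number of optimizer steps up front."""
    return num_epochs * math.ceil(num_train / batch_size)
```

Returning the counts from the dataloader helper means the scheduler setup never has to duplicate the hard-coded numbers.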

For DDP, I added a train_combo.py script, which is basically the old train_prior_with_voxel2clip.py updated to work with DDP. The main change is to use the unwrapped PyTorch models via model.module during validation on the master process, which prevents hanging. I also made the validation DataLoader use the single-device batch size and only one worker, since it runs only on the master process; this way we see the full set of validation samples each epoch.
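The unwrapping trick can be sketched as below. This is a minimal illustration, not the train_combo.py code; the evaluate/local_rank names in the comment are hypothetical. DistributedDataParallel stores the original model in its .module attribute, and calling the DDP wrapper on only one rank can trigger collective ops the other ranks never join, which is what causes the hang.

```python
def unwrap(model):
    """Return the underlying model if `model` is a DDP wrapper.

    DistributedDataParallel keeps the wrapped model in `.module`;
    a plain (unwrapped) model is returned unchanged, so this is safe
    to call whether or not DDP is in use.
    """
    return getattr(model, "module", model)

# Master-only validation sketch (illustrative names):
# if local_rank == 0:
#     val_loader = DataLoader(val_set, batch_size=per_device_batch_size,
#                             num_workers=1, shuffle=False)
#     evaluate(unwrap(model), val_loader)  # bypasses DDP collectives
```

Using the single-device batch size (rather than the global one) and a single worker keeps the master-only validation pass simple while still iterating over every validation sample.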

The reason for a new train_combo.py script, instead of reusing the old one, is that I was developing on the cluster with Slurm, so running a notebook that gets converted to a Python file wasn't feasible. If I modified the converted Python script, there would be no way to get the changes back into the notebook. I made a best effort to update the notebook so that it at least works, but train_combo.py is the only thing I've tested.

Other random things: