ivadomed / model-seg-dcm

Segmentation of lesions on MRI scans in patients with Degenerative Cervical Myelopathy (DCM)
MIT License

Supervised pre-training on SC segmentations using `swinunetr` #12

Status: Open · valosekj opened this issue 3 months ago

valosekj commented 3 months ago

Description

This issue summarizes some early experiments with supervised pre-training on SC (spinal cord) segmentations.

WIP branch: nk/jv_vit_unetr_ssl
Pre-training script: pretraining_and_finetuning/main_supervised_pretraining.py

Experiments

Unlike the self-supervised learning (SSL) experiments in #7 and #9, the pre-training in this issue is supervised, using SC segmentations as targets.

T2w images for the supervised pre-training come from 5 datasets (canproco, dcm-zurich, sci-colorado, sci-paris, spine-generic multi-subject). Number of training samples: 654. Number of validation samples: 163. Details about the images are provided in dataset-conversion/README.md.

I'm currently training two different swinunetr models, both with crop_pad_size: [64, 160, 320] and patch_size: [64, 64, 64].
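To make the two sizes concrete, here is a minimal numpy sketch under the assumption that crop_pad_size means each volume is first center-cropped or zero-padded to [64, 160, 320] (similar in spirit to MONAI's ResizeWithPadOrCropd) and that patch_size is the shape of random training patches sampled from that volume. The helpers center_crop_or_pad and random_patch are hypothetical; the transforms actually used live in pretraining_and_finetuning/transforms.py.

```python
import numpy as np


def center_crop_or_pad(vol, target_shape):
    """Center-crop or zero-pad each axis of `vol` to `target_shape` (assumed meaning of crop_pad_size)."""
    out = vol
    for axis, target in enumerate(target_shape):
        size = out.shape[axis]
        if size > target:
            # Too large: keep the central `target` voxels along this axis.
            start = (size - target) // 2
            out = np.take(out, np.arange(start, start + target), axis=axis)
        elif size < target:
            # Too small: pad with zeros symmetrically (extra voxel goes at the end).
            before = (target - size) // 2
            pad = [(0, 0)] * out.ndim
            pad[axis] = (before, target - size - before)
            out = np.pad(out, pad)
    return out


def random_patch(vol, patch_size, rng):
    """Sample one patch of `patch_size` at a random position fully inside `vol`."""
    starts = [int(rng.integers(0, s - p + 1)) for s, p in zip(vol.shape, patch_size)]
    return vol[tuple(slice(st, st + p) for st, p in zip(starts, patch_size))]


rng = np.random.default_rng(0)
img = rng.random((48, 200, 300))  # toy volume with an arbitrary shape
cropped = center_crop_or_pad(img, (64, 160, 320))
patch = random_patch(cropped, (64, 64, 64), rng)
print(cropped.shape, patch.shape)  # (64, 160, 320) (64, 64, 64)
```

With these sizes, one 64×64×64 patch spans the full first axis of the cropped volume but only 2/5 of the second axis and 1/5 of the third.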

Experiment 1 - Model with CropForegroundd - multiple datasets

This model uses transforms.CropForegroundd(keys=all_keys, source_key="label_sc") to crop every item in all_keys to the bounding box of the SC mask (label_sc), discarding everything outside it. See pretraining_and_finetuning/transforms.py.
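For reference, the cropping that CropForegroundd performs can be sketched in numpy as below. This is a simplified stand-in written for illustration (crop_foreground is a hypothetical helper, not MONAI code); it assumes, like MONAI's default select_fn, that foreground means voxels > 0 in the source key, and that all keys share the same spatial shape.

```python
import numpy as np


def crop_foreground(data, source_key, margin=0):
    """Crop every array in `data` to the bounding box of the foreground
    (voxels > 0) of data[source_key], optionally enlarged by `margin` voxels."""
    mask = data[source_key] > 0
    coords = np.argwhere(mask)  # (N, ndim) indices of foreground voxels
    if coords.size == 0:
        # Choice made in this sketch: with no foreground, return arrays uncropped.
        return dict(data)
    lo = np.maximum(coords.min(axis=0) - margin, 0)
    hi = np.minimum(coords.max(axis=0) + 1 + margin, mask.shape)
    box = tuple(slice(int(start), int(stop)) for start, stop in zip(lo, hi))
    return {key: arr[box] for key, arr in data.items()}


rng = np.random.default_rng(0)
image = rng.random((10, 20, 30))
label_sc = np.zeros_like(image)
label_sc[2:5, 8:12, 10:25] = 1  # toy "spinal cord" mask
out = crop_foreground({"image": image, "label_sc": label_sc}, source_key="label_sc")
print(out["image"].shape)  # (3, 4, 15)
```

Because every key is cropped to the SC bounding box, a model trained with this transform only sees the region immediately around the cord.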

GIF of validation samples (note: validation is done every 5 epochs):
![gif](https://github.com/ivadomed/model-seg-dcm/assets/39456460/ec6cdfc6-d19c-47e2-b39a-0c55a9063b8c)

Validation hard dice dropped to zero after ~100 epochs:

Loss plots (note: validation is done every 5 epochs):
![loss_plots](https://github.com/ivadomed/model-seg-dcm/assets/39456460/ce2317dc-8ddf-4eed-91db-35f7c7102790)
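A hard Dice of exactly zero is a strong signal. Assuming "hard dice" means Dice computed after thresholding the predicted probabilities (a 0.5 threshold is assumed here), it becomes zero as soon as no voxel above the threshold overlaps the SC mask, for example when the network outputs low probabilities everywhere. A minimal sketch (hard_dice is a hypothetical helper, not the repo's metric code):

```python
import numpy as np


def hard_dice(pred_prob, target, threshold=0.5, empty_value=1.0):
    """Dice score after binarizing the soft prediction at `threshold`."""
    pred = np.asarray(pred_prob) > threshold
    gt = np.asarray(target) > 0
    denom = pred.sum() + gt.sum()
    if denom == 0:
        # Both masks empty: the returned value is a convention, not a measurement.
        return empty_value
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)


gt = np.zeros((4, 4, 4))
gt[1:3, 1:3, 1:3] = 1
print(hard_dice(0.9 * gt, gt))  # 1.0
print(hard_dice(0.4 * gt, gt))  # 0.0
```

In the second call, a soft Dice computed on the raw probabilities would still be non-zero, which is why the hard variant can drop straight to zero.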

Experiment 2 - Model without CropForegroundd - multiple datasets

This model does NOT use transforms.CropForegroundd(keys=all_keys, source_key="label_sc").

GIF of validation samples (note: validation is done every 5 epochs):
![gif](https://github.com/ivadomed/model-seg-dcm/assets/39456460/647e5d23-77fd-4212-b0fc-a3715614328b)

Loss plots (note: validation is done every 5 epochs):
![loss_plots](https://github.com/ivadomed/model-seg-dcm/assets/39456460/21759707-d976-4d6b-b49a-61c4382e704e)

Model training crashed with `OSError: [Errno 112] Host is down ...` (possibly because I'm still loading data from duke/temp?). So I resumed training from the best checkpoint (~epoch 65). Training resumed, but then the validation hard dice dropped to zero:

Loss plots after resuming (note: validation is done every 5 epochs):
![loss_plots](https://github.com/ivadomed/model-seg-dcm/assets/39456460/b21222b6-ada7-430d-8577-4d7c890f8789)
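
On Linux, Errno 112 is EHOSTDOWN, i.e. the host backing the data path became unreachable during a read. Besides copying the data to local storage, one option is to retry such transient errors instead of letting the whole run die. The sketch below is hypothetical: load_with_retry is not part of the repo, which errno values count as transient is an assumption, and load_fn stands for whatever function reads an image.

```python
import errno
import time

# Errors assumed to be transient on a network share (Linux errno names).
TRANSIENT_ERRNOS = {errno.EHOSTDOWN, errno.EIO, errno.ESTALE}


def load_with_retry(load_fn, path, retries=3, delay_s=5.0):
    """Call load_fn(path), retrying with exponential backoff on transient OSErrors."""
    for attempt in range(retries + 1):
        try:
            return load_fn(path)
        except OSError as exc:
            if exc.errno not in TRANSIENT_ERRNOS or attempt == retries:
                raise  # not transient, or out of retries: surface the error
            time.sleep(delay_s * 2 ** attempt)


calls = []


def flaky_load(path):
    """Toy loader that fails twice with 'Host is down', then succeeds."""
    calls.append(path)
    if len(calls) < 3:
        raise OSError(errno.EHOSTDOWN, "Host is down")
    return f"loaded {path}"


print(load_with_retry(flaky_load, "sub-01_T2w.nii.gz", delay_s=0.0))  # loaded sub-01_T2w.nii.gz
```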
valosekj commented 3 months ago

Since swinunetr both with and without CropForegroundd collapsed to zero validation dice when trained on T2w images from multiple datasets (details in the comment above), I tried training swinunetr on a single dataset (spine-generic multi-subject). This time, training finished successfully!

Experiment 3 - Model with CropForegroundd - spine-generic only

This model used transforms.CropForegroundd(keys=all_keys, source_key="label_sc"). Number of training samples: 213. Number of validation samples: 54.

Loss plots (note: validation is done every 5 epochs):
![loss_plots](https://github.com/ivadomed/model-seg-dcm/assets/39456460/29e02d37-bd1a-4b39-813a-62fdd94b76ba)

GIF of validation samples (note: validation is done every 5 epochs):
![gif](https://github.com/ivadomed/model-seg-dcm/assets/39456460/8710f610-b384-485e-a60d-ac92b2ee66ec)

Experiment 4 - Model without CropForegroundd - spine-generic only

This model did NOT use transforms.CropForegroundd(keys=all_keys, source_key="label_sc"). Number of training samples: 213. Number of validation samples: 54.

Loss plots (note: validation is done every 5 epochs):
![loss_plots](https://github.com/ivadomed/model-seg-dcm/assets/39456460/7a54b4c1-ecfd-4aad-ac5f-6651054d3982)

GIF of validation samples (note: validation is done every 5 epochs):
![gif](https://github.com/ivadomed/model-seg-dcm/assets/39456460/c4e0b713-7b50-464b-bda0-e7a29ad75b25)

Notice that, besides the SC, the model also predicts other structures, for example:
![val_00199_048 copy](https://github.com/ivadomed/model-seg-dcm/assets/39456460/43025d4a-cdd3-4d5e-8eee-741286b7e713)
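Spurious detections like these are commonly handled with a post-processing step that keeps only the largest connected component of the predicted mask. This was not part of the experiments described here; the sketch below is a plain-numpy illustration using 6-connectivity (keep_largest_component is a hypothetical helper). For full-size volumes, scipy.ndimage.label is the usual, much faster way to label components.

```python
from collections import deque

import numpy as np

# 6-connectivity: neighbors differ by one voxel along a single axis.
OFFSETS = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]


def keep_largest_component(mask):
    """Return a boolean 3D mask containing only the largest 6-connected component."""
    mask = np.asarray(mask, dtype=bool)
    labels = np.zeros(mask.shape, dtype=np.int32)  # 0 = background / unvisited
    sizes = [0]  # sizes[k] = voxel count of component k; index 0 is background
    for seed in zip(*np.nonzero(mask)):
        if labels[seed]:
            continue  # already assigned to a component
        comp = len(sizes)
        labels[seed] = comp
        queue, count = deque([seed]), 0
        while queue:  # breadth-first flood fill from the seed
            z, y, x = queue.popleft()
            count += 1
            for dz, dy, dx in OFFSETS:
                n = (z + dz, y + dy, x + dx)
                inside = all(0 <= c < s for c, s in zip(n, mask.shape))
                if inside and mask[n] and not labels[n]:
                    labels[n] = comp
                    queue.append(n)
        sizes.append(count)
    if len(sizes) == 1:
        return mask.copy()  # empty prediction: nothing to keep
    return labels == int(np.argmax(sizes))


pred = np.zeros((6, 6, 6), dtype=bool)
pred[0:3, 0:2, 0:2] = True  # larger blob (12 voxels), standing in for the SC
pred[5, 5, 5] = True        # isolated false positive
clean = keep_largest_component(pred)
print(int(pred.sum()), int(clean.sum()))  # 13 12
```

This assumes the true SC prediction forms a single connected piece; if the cord prediction itself is fragmented, keeping only the largest component would also remove valid parts of it.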

Conclusion

I originally thought that the training collapse (validation dice dropping to zero) was caused by transforms.CropForegroundd(keys=all_keys, source_key="label_sc"). However, when I trained only on spine-generic, training both with and without transforms.CropForegroundd finished successfully; it seems that the collapse may have originated from using images from multiple datasets.

naga-karthik commented 3 months ago

> it seems that training collapsing may have originated from using images from multiple datasets.

I don't quite agree with this, because I have trained on spine-generic and basel-mp2rage for the contrast-agnostic model and it worked fine. The crash you report on multiple datasets might be specific to that experiment: training stopped and was resumed from a checkpoint, so there might have been an issue with loading the checkpoint and resuming training.

If we compare (1) spine-generic with CropForegroundd and (2) spine-generic + lesion datasets with CropForegroundd, while ensuring that training does not crash at any point, we might reach a different conclusion!

valosekj commented 3 months ago

Thanks @naga-karthik!

> (2) spine-generic + lesion datasets with CropForegroundd

I tried the following experiment:

swinunetr with CropForegroundd, pre-trained for SC segmentation on three datasets (spine-generic multi-subject, dcm-zurich, and sci-paris). Training finished without any crashes!

Loss plots (note: validation is done every 5 epochs):
![loss_plots](https://github.com/ivadomed/model-seg-dcm/assets/39456460/9e5e9f26-bd69-46f7-803b-96a904f15619)

GIF of validation samples (note: validation is done every 5 epochs):
![gif](https://github.com/ivadomed/model-seg-dcm/assets/39456460/f81b8db5-e275-4be2-8a1b-5c7b656be525)

So now that we have several pre-trained models, I'm moving on to fine-tuning on lesions!


BTW, it's hard to say what caused the training crash in https://github.com/ivadomed/model-seg-dcm/issues/12#issue-2228472084. I'll try to figure this out later.

naga-karthik commented 3 months ago

Do you also have some pre-trained nnunet or monai-unet models?