sct-pipeline / contrast-agnostic-softseg-spinalcord

Contrast-agnostic spinal cord segmentation project with softseg
MIT License

Training SynthSeg on spine-generic data #111

Open naga-karthik opened 5 months ago

naga-karthik commented 5 months ago

This issue summarizes my attempt at training SynthSeg on the spine-generic multi-subject dataset.

Brief context: SynthSeg was originally proposed for segmenting brain scans of any resolution and contrast. Because it is "contrast-agnostic", I am trying to re-train the model as a baseline for our contrast-agnostic spinal cord (SC) segmentation model. An important note is that SynthSeg requires fully-labeled brain scans, from which it synthesizes fake brain images (sampled from a Gaussian mixture model, GMM) and then trains the segmentation model. A notable challenge in re-training SynthSeg for SC images is therefore that it would require labels for every structure in an SC scan (i.e. those of the SC, cerebrospinal fluid, vertebrae, bones, brain, etc.).

As it is not feasible to obtain labels for every anatomical region in an SC image, I eased the constraints by focusing only on the segmentation of 4 structures: (i) vertebrae, (ii) discs, (iii) SC, and (iv) cerebrospinal fluid (CSF). The labels for these regions were obtained using the TotalSegmentatorMRI model.

Experiments

The SynthSeg repo contains well-described tutorials for re-training your own models. The following sections will describe how I have tweaked SynthSeg for SC data.

Defining the labels for generating synthetic images

There are 4 key elements in generating synthetic images based on labels:

  1. Defining the path to the training label maps: In my case, this corresponds to the folder containing the TotalSeg labels. I only used the labels for T1w and T2w contrasts as the labels were not available for the rest of the spine-generic contrasts.
  2. Label classes for generation: In my case, these correspond to the ids for the background, vertebrae, discs, SC and CSF. Following the preliminary version of the TotalSpineSeg model, I used the following ids to define the labels. Note that reusing the ids from the tutorial would not make sense, as those labels correspond to the brain (and not the SC).
     ```python
     import numpy as np

     gen_labels_cerv_totalseg = np.array([0,  # background
         # vertebrae (covering C1 to T12)
         41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23,
         # discs (covering C1/C2 to T11/T12)
         224, 223, 222, 221, 220, 219, 218, 217, 216, 215, 214, 213, 212, 211, 210, 209, 208,
         # SC and CSF
         200, 201,
     ])
     ```
  3. Defining the output label maps: To keep it simple, I used the same labels as above, so the synthetic images contain all classes defined in the label generation array (SynthSeg provides an option to disable labels you don't want to generate). Also for simplicity, all vertebrae are merged into one class and all discs into another (class ids shown below).
  4. Defining the generation classes: These correspond to the classes that the segmentation model learns to output during training. Unlike the arbitrary ids above, these are typically consecutive ids starting from 0 up to the total number of classes. In my case, I provided the following 4 classes:
    1: SC
    2: CSF
    3: Vertebrae
    4: Discs

These are essentially the 4 key parameters in the procedure for generating synthetic images from the GMM.
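To make items 3 and 4 concrete, here is a minimal NumPy sketch of collapsing the TotalSeg ids into the 4 segmentation classes (plus background 0). It is not SynthSeg code: `build_class_lut` and `to_classes` are hypothetical helpers, and the id ranges are simply read off `gen_labels_cerv_totalseg`.

```python
import numpy as np

# Generation ids read off gen_labels_cerv_totalseg (hypothetical constants)
VERTEBRAE_IDS = list(range(41, 22, -1))  # 41, 40, ..., 23
DISC_IDS = list(range(224, 207, -1))     # 224, 223, ..., 208
SC_ID, CSF_ID = 200, 201


def build_class_lut(max_id=255):
    """Hypothetical lookup table mapping each generation id to a segmentation class."""
    lut = np.zeros(max_id + 1, dtype=np.uint8)  # ids not listed stay 0 (background)
    lut[SC_ID] = 1
    lut[CSF_ID] = 2
    lut[VERTEBRAE_IDS] = 3  # all vertebrae merged into one class
    lut[DISC_IDS] = 4       # all discs merged into one class
    return lut


def to_classes(label_map, lut):
    """Collapse an integer label map into classes 0-4 by table lookup."""
    return lut[label_map]


lut = build_class_lut()
toy = np.array([[0, 41, 23], [224, 200, 201]])
classes = to_classes(toy, lut)
```

A lookup table keeps the remapping to a single vectorized indexing operation, which matters for full 3D label maps.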

naga-karthik commented 5 months ago

Below are a few examples of synthetic images sampled from the GMM:

Example 1 ![ezgif com-animated-gif-maker](https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/assets/53445351/3b45630e-eccd-4dcc-9552-0233e7b9d59e)
Example 2 ![ezgif com-animated-gif-maker-2](https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/assets/53445351/9b691326-4b06-4a20-85d4-cfe330d7ed57)
Example 3 ![ezgif com-animated-gif-maker-3](https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/assets/53445351/9643e96c-3260-46b3-a862-ce1c1c6a70bb)

As expected, the synthetic images look very unrealistic. That is the point of SynthSeg: training on heavily deformed, unrealistic images is what makes it contrast-agnostic.
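A minimal sketch of why the images look like this: each label is filled with intensities from its own randomly drawn Gaussian, so the same anatomy gets a new, arbitrary "contrast" on every draw. This is a simplified stand-in, not SynthSeg's implementation; `synthesize_from_labels` and the mean/std ranges are assumptions, and the deformation, bias field, blurring and resampling steps SynthSeg also applies are left out.

```python
import numpy as np


def synthesize_from_labels(label_map, rng, mean_range=(0.0, 255.0), std_range=(1.0, 25.0)):
    """Sketch of GMM-style synthesis: one random Gaussian per label.

    Every call draws new (mean, std) pairs, so the same anatomy gets a new,
    arbitrary "contrast" each time. Spatial deformation, bias field,
    blurring and resampling are deliberately left out.
    """
    image = np.zeros(label_map.shape, dtype=np.float32)
    for label in np.unique(label_map):
        mask = label_map == label
        mean = rng.uniform(*mean_range)  # assumed prior range for the mean
        std = rng.uniform(*std_range)    # assumed prior range for the std
        image[mask] = rng.normal(mean, std, size=int(mask.sum()))
    return image


rng = np.random.default_rng(0)
labels = np.array([[0, 0, 200], [201, 41, 224]])
img_a = synthesize_from_labels(labels, rng)
img_b = synthesize_from_labels(labels, rng)  # same labels, different "contrast"
```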

naga-karthik commented 5 months ago

Training

The training procedure borrows several aspects of the synthetic image generation procedure (described in the first comment) and adds the regular training hyperparameters. To stay close to our contrast-agnostic model for SC, I set the output shape to 64 x 192 x 256. I also set bias_field_std=0.5, as this is the range we chose in our transforms.
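For intuition on `bias_field_std`, here is a NumPy-only sketch of a smooth multiplicative bias field, assuming the scheme described in the SynthSeg docstrings: a small tensor is drawn from a normal distribution with std `bias_field_std`, resized to full size and exponentiated. The helper names and the `coarse_shape` default are hypothetical choices for illustration.

```python
import numpy as np


def resize_linear(arr, out_shape):
    """Separable linear interpolation of an N-D array to out_shape (NumPy only)."""
    for axis, n_out in enumerate(out_shape):
        x_in = np.linspace(0.0, 1.0, arr.shape[axis])
        x_out = np.linspace(0.0, 1.0, n_out)
        arr = np.apply_along_axis(lambda v: np.interp(x_out, x_in, v), axis, arr)
    return arr


def sample_bias_field(shape, rng, bias_field_std=0.5, coarse_shape=(4, 4, 4)):
    """Smooth, positive, multiplicative bias field.

    A coarse log-field is drawn from N(0, bias_field_std), upsampled to the
    image shape and exponentiated so that every voxel gets a positive gain.
    """
    log_field = rng.normal(0.0, bias_field_std, size=coarse_shape)
    return np.exp(resize_linear(log_field, shape))


rng = np.random.default_rng(0)
field = sample_bias_field((8, 24, 32), rng)  # small shape to keep the demo fast
```

The corrupted image is then `image * field`; for a full 64 x 192 x 256 volume, `scipy.ndimage.zoom` would be a faster way to upsample.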

When initially training for 5 epochs, the loss became NaN. It seems that NaN loss values were also encountered here by others re-training SynthSeg. More debugging to follow...

naga-karthik commented 4 months ago

I was able to fix the NaN issue, and the first set of results is in. Each of the following comments describes one experiment along with the results obtained. The generation labels and the segmentation classes are the same as described in https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/111#issue-2364557862. A total of 209 labeled cervical spine scans were used, with labels for the cord, CSF, vertebrae and intervertebral discs obtained from the T1w and T2w contrasts. (Note: SynthSeg primarily used fully-labeled T1w brain scans.)

Experiment 1

The following hyperparameters are defined in this training script, which I modified to stay close to our contrast-agnostic model for SC segmentation. The hyperparameters listed below are the only ones changed for this experiment; all others are kept at their default values.

```python
n_neutral_labels = None   # setting this to an incorrect value is what I believe caused the NaN values above

# shape and resolution of the outputs
target_res = None   # SynthSeg outputs everything at 1 mm isotropic by default; we don't want that, so it is set to None
output_shape = (64, 192, 256)    # using the original shape of the labels (192, 260, 320) resulted in OOM errors;
                                 # the cropped shape is also close to what we use in contrast-agnostic (64, 192, 320), i.e. heavy cropping along R-L

# spatial deformation parameters
flipping = False
rotation_bounds = 20
shearing_bounds = .012
bias_field_std = .5

# architecture parameters
activation = 'relu'     # we used ReLU in contrast-agnostic, so changed it from the default 'elu'

# training parameters
dice_epochs = 25        # number of training epochs
steps_per_epoch = 1000  # number of iterations per epoch
batchsize = 2
```
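To illustrate what `rotation_bounds` and `shearing_bounds` control, here is a sketch of a random 3x3 linear transform, assuming (as the SynthSeg docstrings describe) that rotation angles in degrees and shearing terms are drawn uniformly in [-bound, bound]. The helper name and the composition order are my own choices, not SynthSeg's code.

```python
import numpy as np


def random_linear_transform(rng, rotation_bounds=20.0, shearing_bounds=0.012):
    """Random 3x3 rotation-and-shear matrix (angles in degrees)."""
    # one rotation angle per axis, drawn in [-rotation_bounds, rotation_bounds]
    angles = np.deg2rad(rng.uniform(-rotation_bounds, rotation_bounds, size=3))
    cx, cy, cz = np.cos(angles)
    sx, sy, sz = np.sin(angles)
    rot_x = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    rot_y = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rot_z = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    # six off-diagonal shearing terms, each drawn in [-shearing_bounds, shearing_bounds]
    shear = np.eye(3)
    off_diagonal = ~np.eye(3, dtype=bool)
    shear[off_diagonal] = rng.uniform(-shearing_bounds, shearing_bounds, size=6)
    return rot_z @ rot_y @ rot_x @ shear


rng = np.random.default_rng(0)
A = random_linear_transform(rng)
```

With shearing bounds this small, the transform is close to a pure rotation (determinant near 1).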

Results

Test sample 1: **sub-barcelona05_T1w** [screenshot]
Test sample 2: **sub-amu05_T1w** [screenshot]

The predictions are poor: they fail to capture the structure of the cervical spine, and none of the labels (cord, CSF, vertebrae, discs) is properly segmented.
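To put numbers on such failures, a per-class Dice score over the 4 classes could complement the screenshots. This is a generic sketch (hypothetical `per_class_dice`), not part of the SynthSeg code base.

```python
import numpy as np


def per_class_dice(pred, gt, classes=(1, 2, 3, 4)):
    """Dice per class; NaN when a class is absent from both volumes."""
    scores = {}
    for c in classes:
        p, g = pred == c, gt == c
        denom = p.sum() + g.sum()
        scores[c] = 2.0 * np.logical_and(p, g).sum() / denom if denom > 0 else float("nan")
    return scores


gt = np.array([1, 1, 2, 0, 3])
pred = np.array([1, 0, 2, 0, 0])
scores = per_class_dice(pred, gt)
```

Note that an empty prediction scores 0 for every class present in the ground truth.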

naga-karthik commented 4 months ago

Experiment 2

Since flipping=True is the default in SynthSeg, I trained another model with flipping set to True, keeping the rest of the hyperparameters defined in https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/111#issuecomment-2206186379 the same.

Test sample 1: **sub-barcelona05_T1w** [screenshot]

Surprisingly, the model does not predict anything when trained with flipping=True.
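For reference, here is a sketch of what label-aware flipping does, assuming the ordering convention from the SynthSeg tutorials (the first `n_neutral_labels` generation ids are non-sided, followed by all left ids, then the right ids in the same order). `flip_lr` is hypothetical. Since none of our structures (cord, CSF, vertebrae, discs) is lateralized, all generation labels would count as neutral and flipping reduces to a plain R-L mirror; the sketch does not by itself explain the empty predictions.

```python
import numpy as np


def flip_lr(image, labels, generation_labels, n_neutral_labels, axis=0):
    """Mirror image and labels along `axis`, swapping left/right label ids.

    Assumes the first n_neutral_labels ids are non-sided, then all left ids,
    then the right ids in the same order.
    """
    flipped_img = np.flip(image, axis=axis)
    flipped_lab = np.flip(labels, axis=axis)
    sided = np.asarray(generation_labels)[n_neutral_labels:]
    half = len(sided) // 2
    out = flipped_lab.copy()
    for left_id, right_id in zip(sided[:half], sided[half:]):
        out[flipped_lab == left_id] = right_id
        out[flipped_lab == right_id] = left_id
    return flipped_img, out
```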

naga-karthik commented 4 months ago

Experiment 3

To simplify the segmentation problem, I re-trained the model with only 2 output classes (the spinal cord and the CSF). Specifically, given all the labels below, the model is trained to output only the SC and CSF classes, with label values 200 and 201 (the rest of the hyperparameters are the same as in Experiment 1):

```python
import numpy as np

gen_labels_cerv_totalseg = np.array([0,  # background
    # vertebrae (covering C1 to T12)
    41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23,
    # discs (covering C1/C2 to T11/T12)
    224, 223, 222, 221, 220, 219, 218, 217, 216, 215, 214, 213, 212, 211, 210, 209, 208,
    # SC and CSF --> only output these labels
    200, 201,
])
```
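The training target in this experiment (vertebrae and discs still generated as image content, but dropped from the ground truth) can be sketched as follows. `keep_only` is a hypothetical helper; in SynthSeg itself this is configured through its label settings rather than a function like this.

```python
import numpy as np


def keep_only(label_map, kept_ids):
    """Map kept_ids to consecutive classes 1..K and everything else to 0."""
    out = np.zeros_like(label_map)
    for new_id, old_id in enumerate(kept_ids, start=1):
        out[label_map == old_id] = new_id
    return out


toy = np.array([0, 41, 224, 200, 201, 200])
target = keep_only(toy, kept_ids=[200, 201])
```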

Here as well, the model fails to output anything: the predictions are empty for all test images.

Test sample 1: **sub_brnoUhb04_T1w** [screenshot]

Given this string of poor results, I am not sure what the actual reason is for SynthSeg failing on SC images. One potential issue is that, at test time, the input T1w images contain both the brain and the SC (as in any typical T1w SC image), whereas the synthetic scans that SynthSeg generates during training contain only the 4 labeled classes. Maybe this discrepancy is one issue? But fixing it would mean that we need all the labels (i.e. those of the brain and those of the SC), which is really impractical.

jcohenadad commented 4 months ago

> But, to fix this, it would mean that we need all the labels (i.e. for those of the brain and those of the SC, which is really impractical)

Indeed, let's not go there.