frankkramer-lab / MIScnn

A framework for Medical Image Segmentation with Convolutional Neural Networks and Deep Learning
GNU General Public License v3.0

Pretrained model for fine tuning #66

Closed JimHeo closed 3 years ago

JimHeo commented 3 years ago

Hi, this is Jim.

First, thanks for the library.

I found your pretrained COVID model in https://zenodo.org/record/3902293. I'm just trying to train KiTS19 using the pretrained model.

But, because of the output shape mismatch, I can't fine tune the model... In this case, how can I train it?

muellerdo commented 3 years ago

Hey Jim,

thank you for your kind words and your interest in using MIScnn.

Mhm. Because MIScnn works on patches (patch_shape=(160, 160, 80)), it should also work with samples of different shapes. You may have to add padding as a subfunction to ensure minimal patch sizes, and you definitely have to use the same clipping, normalization and resizing to appropriately re-use the pretrained weights of the model.

But, because of the output shape mismatch, I can't fine tune the model...

Could you please provide an error log? It should work out of the box if you copied the script from the repository and just updated the input to KiTS19. https://github.com/frankkramer-lab/covid19.MIScnn/blob/master/scripts/run_miscnn.py

But interesting idea, since transfer learning is currently quite rare in medical image segmentation. After training for a few epochs it may be interesting to see what happens if you adjust the clipping to a smaller range (kidney and tumor Hounsfield Units).

Cheers, Dominik

JimHeo commented 3 years ago

Sure, the error log:

Traceback (most recent call last):
  File "kits19_train.py", line 85, in <module>
    model.train(training_samples, epochs=500, callbacks=[cb_lr, cb_es, cb_cp])
  File "/home/jim/github/MIScnn/miscnn/neural_network/model.py", line 133, in train
    max_queue_size=self.batch_queue_size)
  File "/home/jim/anaconda3/envs/miscnn/lib/python3.7/site-packages/tensorflow-2.4.1-py3.7-linux-x86_64.egg/tensorflow/python/keras/engine/training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/jim/anaconda3/envs/miscnn/lib/python3.7/site-packages/tensorflow-2.4.1-py3.7-linux-x86_64.egg/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/home/jim/anaconda3/envs/miscnn/lib/python3.7/site-packages/tensorflow-2.4.1-py3.7-linux-x86_64.egg/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/jim/anaconda3/envs/miscnn/lib/python3.7/site-packages/tensorflow-2.4.1-py3.7-linux-x86_64.egg/tensorflow/python/eager/function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
  File "/home/jim/anaconda3/envs/miscnn/lib/python3.7/site-packages/tensorflow-2.4.1-py3.7-linux-x86_64.egg/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/jim/anaconda3/envs/miscnn/lib/python3.7/site-packages/tensorflow-2.4.1-py3.7-linux-x86_64.egg/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/home/jim/anaconda3/envs/miscnn/lib/python3.7/site-packages/tensorflow-2.4.1-py3.7-linux-x86_64.egg/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  logits and labels must be broadcastable: logits_size=[8192000,4] labels_size=[8192000,3]
         [[node softmax_cross_entropy_with_logits (defined at /home/jim/github/MIScnn/miscnn/neural_network/metrics.py:85) ]] [Op:__inference_train_function_9815]

Function call stack:
train_function

2021-02-04 10:41:57.504605: W tensorflow/core/kernels/data/generator_dataset_op.cc:107] Error occurred when finalizing GeneratorDataset iterator: Failed precondition: Python interpreter state is not initialized. The process may be terminated.

And here is the training code (following the given KiTS19 tutorial ipynb):

import tensorflow as tf
import os
from tensorflow.python.keras.saving.saving_utils import model_metadata
from miscnn.data_loading.interfaces.nifti_io import NIFTI_interface
from miscnn.data_loading.data_io import Data_IO
from miscnn.processing.data_augmentation import Data_Augmentation
from miscnn.processing.subfunctions.normalization import Normalization
from miscnn.processing.subfunctions.clipping import Clipping
from miscnn.processing.subfunctions.resampling import Resampling
from miscnn.processing.preprocessor import Preprocessor
from miscnn.neural_network.model import Neural_Network
from miscnn.neural_network.architecture.unet.standard import Architecture
from miscnn.neural_network.metrics import dice_soft, dice_crossentropy, tversky_loss
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ModelCheckpoint

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Initialize the NIfTI I/O interface and configure the images as one channel (grayscale) and three segmentation classes (background, kidney, tumor)
interface = NIFTI_interface(pattern="case_00[0-9]*", channels=1, classes=3)

# Specify the kits19 data directory
data_path = "/home/jim/large_data/kits19/data/"
# Create the Data I/O object 
data_io = Data_IO(interface, data_path)

sample_list = data_io.get_indiceslist()
sample_list.sort()

# Create and configure the Data Augmentation class
data_aug = Data_Augmentation(cycles=2, scaling=True, rotations=True, elastic_deform=True, mirror=True,
                             brightness=True, contrast=True, gamma=True, gaussian_noise=True)

# Create a pixel value normalization Subfunction through Z-Score 
sf_normalize = Normalization(mode='z-score')
# Create a clipping Subfunction between -79 and 304
sf_clipping = Clipping(min=-79, max=304)
# Create a resampling Subfunction to voxel spacing 3.22 x 1.62 x 1.62
sf_resample = Resampling((3.22, 1.62, 1.62))

# Assemble Subfunction classes into a list
# Be aware that the Subfunctions will be executed according to the list order!
subfunctions = [sf_resample, sf_clipping, sf_normalize]

# Create and configure the Preprocessor class
pp = Preprocessor(data_io, data_aug=data_aug, batch_size=4, subfunctions=subfunctions, prepare_subfunctions=True, 
                  prepare_batches=False, analysis="patchwise-crop", patch_shape=(80, 160, 160),
                  use_multiprocessing=True)

# Adjust the patch overlap for predictions
pp.patchwise_overlap = (40, 80, 80)

# Create the Neural Network model
unet_standard = Architecture(depth=4, activation="softmax", batch_normalization=True)
model = Neural_Network(preprocessor=pp, architecture=unet_standard, loss=tversky_loss, metrics=[dice_soft, dice_crossentropy], learninig_rate=0.0001)

# Define Callbacks
cb_lr = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=20, verbose=1, mode='min', min_delta=0.0001, cooldown=1, min_lr=0.00001)
cb_es = EarlyStopping(monitor='loss', min_delta=0, patience=150, verbose=1, mode='min')
cb_cp = ModelCheckpoint("models/kits_unet.{epoch:02d}.hdf5", monitor='val_loss', verbose=1, save_freq=90*20)

# Exclude suspicious samples from data set
del sample_list[133]
del sample_list[125]
del sample_list[68]
del sample_list[37]
del sample_list[23]
del sample_list[15]

training_samples = sample_list[0:180]
# validation_samples = sample_list[180:204]
model.load("models/COVID19/model.fold_0.best_loss.hdf5")
model.train(training_samples, epochs=500, callbacks=[cb_lr, cb_es, cb_cp])
model.dump("models/kits19_unet.hdf5")

I think the reason is the difference in the number of classes: the COVID19 model has 4 classes, while KiTS19 has 3. So I just want to replace the last layer, but I don't know how...
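The numbers in the error message can be reproduced from the configuration in the script above. This is only a small arithmetic sketch: the batch size and patch shape are taken from the posted Preprocessor call, the class counts from the error message and the interface setup, and all variable names are made up for illustration.

```python
# Values copied from the training script and the error message above.
batch_size = 4
patch_shape = (80, 160, 160)

# Number of voxels in one batch of patches.
voxels_per_batch = batch_size
for dim in patch_shape:
    voxels_per_batch *= dim

pretrained_classes = 4  # output channels of the loaded COVID-19 model
kits19_classes = 3      # background, kidney, tumor

# The loss flattens predictions and labels to (voxels, classes).
logits_size = [voxels_per_batch, pretrained_classes]
labels_size = [voxels_per_batch, kits19_classes]

print(logits_size, labels_size)
```

The first dimension matches, so the patching is consistent; only the last (class) dimension differs, which points at the output layer of the loaded model.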

muellerdo commented 3 years ago

Hello @JimHeo,

sorry for the late reply.

You are right, the different number of classes requires adding a new classification head. Normally, when a pretrained model is used for a different task, all layers (and their weights) are reused except the classification head. We have to do the same here.

Theoretically, you should perform the following approach:

You can find some posts about this topic here:

Additionally, if you are using transfer learning, I would recommend setting all layers except the classification head to non-trainable for ~3-5 epochs (using a high learning rate). Then fine-tune the model for n epochs with all layers trainable (using a small learning rate).
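A minimal sketch of this idea in plain Keras, under several assumptions: the tiny network built here is only a stand-in for a pretrained model (it is not the MIScnn U-Net), the last layer is assumed to be the classification head, and the helper names, layer names and learning rates are hypothetical. How this maps onto the MIScnn Neural_Network wrapper is not covered here.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model


def build_toy_segmenter(n_classes):
    """Tiny stand-in for a pretrained 3D segmentation network."""
    inputs = layers.Input(shape=(None, None, None, 1))
    x = layers.Conv3D(8, 3, padding="same", activation="relu", name="feat_1")(inputs)
    x = layers.Conv3D(8, 3, padding="same", activation="relu", name="feat_2")(x)
    outputs = layers.Conv3D(n_classes, 1, activation="softmax", name="head")(x)
    return Model(inputs, outputs)


def replace_head(pretrained, n_classes):
    """Reuse all layers except the last one and attach a fresh softmax head."""
    features = pretrained.layers[-2].output
    new_out = layers.Conv3D(n_classes, 1, activation="softmax", name="new_head")(features)
    return Model(pretrained.input, new_out)


def set_backbone_trainable(model, trainable):
    """Freeze or unfreeze every layer except the classification head."""
    for layer in model.layers[:-1]:
        layer.trainable = trainable


# In practice the 4-class model would be loaded from the pretrained weights.
pretrained = build_toy_segmenter(n_classes=4)
model = replace_head(pretrained, n_classes=3)

# Phase 1: train only the new head for a few epochs with a higher learning rate.
set_backbone_trainable(model, False)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss="categorical_crossentropy")
n_trainable_frozen = len(model.trainable_weights)
# model.fit(...)  # ~3-5 epochs

# Phase 2: unfreeze everything and fine-tune with a small learning rate.
set_backbone_trainable(model, True)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="categorical_crossentropy")
n_trainable_full = len(model.trainable_weights)
# model.fit(...)  # n epochs
```

The model is recompiled after each change to the trainable flags so that the new setting takes effect in training.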

Cheers, Dominik

JimHeo commented 3 years ago

I'm training the model from scratch for now. After that, I will try transfer learning.

Appreciate your help :)

Best regards, Jim