ellisdg / 3DUnetCNN

Pytorch 3D U-Net Convolution Neural Network (CNN) designed for medical image segmentation
MIT License

I don't know what the code means, please help me, thank you~ #3

Closed Kaido0 closed 7 years ago

Kaido0 commented 7 years ago

When I ran this code following the closed issue, I understood most of it, but there are a lot of paths and files that have to be set inside the code, which confused me. Then, when I ran it, something went wrong that I can't fix, and I need some help, please.

Traceback (most recent call last):
  File "/home/kaido/workspace/3DUnet/UnetTraining.py", line 226, in <module>
    main(overwrite=False)
  File "/home/kaido/workspace/3DUnet/UnetTraining.py", line 163, in main
    train_model(model, model_file, overwrite=overwrite, iterations=training_iterations)
  File "/home/kaido/workspace/3DUnet/UnetTraining.py", line 191, in train_model
    subjects[dirname.split('_')[-2]] = dirname
IndexError: list index out of range
Kaido0 commented 7 years ago

This is the whole code:

# -*- coding: utf-8 -*-

import os
import glob
import pickle
import datetime

import numpy as np

from keras.layers import (Conv3D, AveragePooling3D, MaxPooling3D, Activation, UpSampling3D, merge, Input, Reshape,
                          Permute)
from keras import backend as K
from keras.models import Model, load_model
from keras.optimizers import Adam

import SimpleITK as sitk

# The BRATS dataset also contains T2 scans

pool_size = (2, 2, 2)
image_shape = (144, 240, 240)
n_channels = 2
# n_channels is the number of modalities (T1c, FLAIR (fluid-attenuated inversion recovery MRI), etc.)

input_shape = tuple([n_channels] + list(image_shape))
n_labels = 5
batch_size = 1
n_test_subjects = 40
z_crop = 155 - image_shape[0]
training_iterations = 5

def pickle_dump(item, out_file):
    with open(out_file, "wb") as opened_file:
        pickle.dump(item, opened_file)

def pickle_load(in_file):
    with open(in_file, "rb") as opened_file:
        return pickle.load(opened_file)

K.set_image_dim_ordering('th')
smooth = 1.

def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)

def unet_model():
    inputs = Input(input_shape)
    conv1 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(inputs)
    conv1 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(conv1)
    pool1 = MaxPooling3D(pool_size=pool_size)(conv1)

    conv2 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(pool1)
    conv2 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(conv2)
    pool2 = MaxPooling3D(pool_size=pool_size)(conv2)

    conv3 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(pool2)
    conv3 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(conv3)
    pool3 = MaxPooling3D(pool_size=pool_size)(conv3)

    conv4 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(pool3)
    conv4 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(conv4)
    pool4 = MaxPooling3D(pool_size=pool_size)(conv4)

    conv5 = Conv3D(512, 3, 3, 3, activation='relu', border_mode='same')(pool4)
    conv5 = Conv3D(512, 3, 3, 3, activation='relu', border_mode='same')(conv5)

    up6 = merge([UpSampling3D(size=pool_size)(conv5), conv4], mode='concat', concat_axis=1)
    conv6 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(up6)
    conv6 = Conv3D(256, 3, 3, 3, activation='relu', border_mode='same')(conv6)

    up7 = merge([UpSampling3D(size=pool_size)(conv6), conv3], mode='concat', concat_axis=1)
    conv7 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(up7)
    conv7 = Conv3D(128, 3, 3, 3, activation='relu', border_mode='same')(conv7)

    up8 = merge([UpSampling3D(size=pool_size)(conv7), conv2], mode='concat', concat_axis=1)
    conv8 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(up8)
    conv8 = Conv3D(64, 3, 3, 3, activation='relu', border_mode='same')(conv8)

    up9 = merge([UpSampling3D(size=pool_size)(conv8), conv1], mode='concat', concat_axis=1)
    conv9 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(up9)
    conv9 = Conv3D(32, 3, 3, 3, activation='relu', border_mode='same')(conv9)

    conv10 = Conv3D(n_labels, 1, 1, 1)(conv9)
    act = Activation('sigmoid')(conv10)

    model = Model(input=inputs, output=act)

    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])

    return model

def train_batch(batch, model):
    x_train = batch[:,:2]
    y_train = get_truth(batch)
    del(batch)
    print(model.train_on_batch(x_train, y_train))
    del(x_train)
    del(y_train)

def read_subject_folder(folder):
    flair_image = sitk.ReadImage(os.path.join(folder, "Flair_subtrMeanDivStd.nii.gz"))
   # t1_image = sitk.ReadImage(os.path.join(folder, "T1.nii.gz"))
    t1c_image = sitk.ReadImage(os.path.join(folder, "T1c_subtrMeanDivStd.nii.gz"))
    truth_image = sitk.ReadImage(os.path.join(folder, "OTMultiClass.nii.gz"))
    #background_image = sitk.ReadImage(os.path.join(folder, "background.nii.gz"))
    return np.array([#sitk.GetArrayFromImage(t1_image),
                     sitk.GetArrayFromImage(t1c_image), 
                     sitk.GetArrayFromImage(flair_image),
                     sitk.GetArrayFromImage(truth_image)])
                     #sitk.GetArrayFromImage(background_image)

# def crop_data(data, background_channel=4):
#     if np.all(data[background_channel, :z_crop] == 1):
#         return data[:, z_crop:]
#     elif np.all(data[background_channel, data.shape[1] - z_crop:] == 1):
#         return data[:, :data.shape[1] - z_crop]
#     else:
#         upper = z_crop/2
#         lower = z_crop - upper
#         return data[:, lower:data.shape[1] - upper]
def crop_data(data, z_crop):
    return data[:, z_crop:]

def get_truth(batch, truth_channel=2):
    truth = np.array(batch)[:, truth_channel]
    batch_list = []
    for sample_number in range(truth.shape[0]):
        sample_list = []
        for label in range(n_labels):
            array = np.zeros(truth[sample_number].shape)
            array[truth[sample_number] == label] = 1
            sample_list.append(array)
        batch_list.append(sample_list)
    return np.array(batch_list)

def get_subject_id(subject_dir):
    return subject_dir.split("_")[-2]

def main(overwrite=False):
    model_file = os.path.abspath("3d_unet_model.h5")  # normalized absolute path to the saved model
    if not overwrite and os.path.exists(model_file):
        model = load_model(model_file, custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
    else:
        model = unet_model()
    train_model(model, model_file, overwrite=overwrite, iterations=training_iterations)

def get_subject_dirs():
    return glob.glob("/home/kaido/workspace/3DUnet/test/*")

def train_model(model, model_file, overwrite=False, iterations=1):
    for i in range(iterations):
        processed_list_file = os.path.abspath("processed_subjects.pkl")
        if overwrite or not os.path.exists(processed_list_file) or i > 0:
            processed_list = []
        else:
            processed_list = pickle_load(processed_list_file)

        subject_dirs = get_subject_dirs()

        testing_ids_file = os.path.abspath("testing_ids.pkl")

        if os.path.exists(testing_ids_file) and not overwrite:
            testing_ids = pickle_load(testing_ids_file)
            if len(testing_ids) > n_test_subjects:
                testing_ids = testing_ids[:n_test_subjects]
                pickle_dump(testing_ids, testing_ids_file)
        else:
            # remove duplicate sessions
            subjects = dict()
            for dirname in subject_dirs:
                subjects[dirname.split('_')[-2]] = dirname

            subject_ids = list(subjects.keys())
            np.random.shuffle(subject_ids)
            testing_ids = subject_ids[:n_test_subjects]
            pickle_dump(testing_ids, testing_ids_file)

        batch = []
        for subject_dir in subject_dirs:

            subject_id = get_subject_id(subject_dir)
            if subject_id in testing_ids or subject_id in processed_list:
                continue

            processed_list.append("Flair_subtrMeanDivStd.nii.gz")
            processed_list.append("T1c_subtrMeanDivStd.nii.gz")
            processed_list.append("OTMultiClass.nii.gz")

            batch.append(crop_data(read_subject_folder(subject_dir), z_crop))
            if len(batch) >= batch_size:
                train_batch(np.array(batch), model)
                del(batch)
                batch = []
                print("Saving: " + model_file)
                pickle_dump(processed_list, processed_list_file)
                model.save(model_file)

        if batch:
            train_batch(np.array(batch), model)
            del(batch)
            print("Saving: " + model_file)
            pickle_dump(processed_list, processed_list_file)
            model.save(model_file)

if __name__ == "__main__":
    main(overwrite=False)
ellisdg commented 7 years ago

Hi @Kaido0, the error is due to a difference in file names. I named my files a certain way so that I could get the subject_id from the file name. Since your files are not named the same way, the split on "_" fails and raises the IndexError.
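
To make that concrete, here is a minimal sketch of what that line assumes; the directory names below are hypothetical examples, not the actual dataset layout:

# Minimal sketch of why `dirname.split('_')[-2]` raises IndexError.
# Both paths below are hypothetical examples.
good_dir = "/data/brats/BRATS_HG0001_1"                 # contains underscores
bad_dir = "/home/kaido/workspace/3DUnet/test/subject1"  # no underscores at all

print(good_dir.split('_'))  # ['/data/brats/BRATS', 'HG0001', '1'] -> index [-2] is 'HG0001'
print(bad_dir.split('_'))   # only one element, so index [-2] raises IndexError: list index out of range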

Have you looked into making a custom data generator for your data? That's what I would suggest using to train the model. I'm working on a data generator and adding it to this repository, but I'm not done with it yet. You can take a look at that code and maybe it will lead you in the right direction.
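
In case it helps, here is a minimal sketch of such a generator, written against the helpers in the script above (read_subject_folder, crop_data, get_truth, z_crop, batch_size) and the old Keras 1 fit_generator signature; treat it as a starting point for your own data, not the repository's final generator:

import numpy as np

def data_generator(subject_dirs, batch_size=1):
    """Yield (inputs, truth) batches forever, as Keras generators are expected to do."""
    while True:
        np.random.shuffle(subject_dirs)  # new subject order on each pass over the data
        batch = []
        for subject_dir in subject_dirs:
            batch.append(crop_data(read_subject_folder(subject_dir), z_crop))
            if len(batch) >= batch_size:
                data = np.array(batch)
                yield data[:, :2], get_truth(data)  # first two channels are the inputs
                batch = []

# Hypothetical usage with the old Keras 1 API:
# model.fit_generator(data_generator(training_dirs, batch_size),
#                     samples_per_epoch=len(training_dirs), nb_epoch=5)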

ellisdg commented 7 years ago

@Kaido0 check out the latest code and #5, as I have updated the code a lot so that it is better able to handle input data.