qubvel / segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
MIT License

data leakage issue while training U-Net SM models with ResNet backbone #527

Open Aliktk opened 2 years ago

Aliktk commented 2 years ago

Hello everyone, I am working on a medical image segmentation task using semantic segmentation.

I load my dataset with a data generator, build my model, and start training with a low batch_size. After a few epochs the memory usage keeps climbing, and by around epoch 14 the memory is exhausted and the kernel dies. I don't know why this is happening. I am sharing my code for data loading and model building below; please help me find where I am making a mistake.

import os

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint
import segmentation_models as sm

# Define constants
SEED = 6
BATCH_SIZE_TRAIN = 128
BATCH_SIZE_TEST = 128

IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
IMG_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

data_dir_train_image = '/home/xylexa/Desktop/train/images/'
array1=os.listdir(data_dir_train_image + 'imgs')
print(len(array1))

data_dir_train_mask = '/home/xylexa/train/masks/'
array2=os.listdir(data_dir_train_mask + 'msk')
print(len(array2))

data_dir_test_image = '/home/xylexa/test/images/'
array3=os.listdir(data_dir_test_image + 'imgs')
print(len(array3))

data_dir_test_mask = '/home/xylexa/test/masks/'
array4=os.listdir(data_dir_test_mask + 'msk')
print(len(array4))
NUM_TRAIN = 12319 
NUM_TEST = 2217

def create_segmentation_generator_train(img_path, msk_path, BATCH_SIZE):
    data_gen_args = dict(rescale=1./255,
                         # featurewise_center=True,
                         # featurewise_std_normalization=True,
                         # rotation_range=90,
                         width_shift_range=0.05,
                         height_shift_range=0.05,
                         zoom_range=0.05)
    datagen = ImageDataGenerator(**data_gen_args)
    # The shared seed keeps image and mask batches aligned.
    img_generator = datagen.flow_from_directory(img_path, target_size=IMG_SIZE, class_mode=None,
                                                color_mode='rgb', batch_size=BATCH_SIZE, seed=SEED)
    msk_generator = datagen.flow_from_directory(msk_path, target_size=IMG_SIZE, class_mode=None,
                                                color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
    return zip(img_generator, msk_generator)

# Remember not to perform any image augmentation in the test generator!
def create_segmentation_generator_test(img_path, msk_path, BATCH_SIZE):
    data_gen_args = dict(rescale=1./255)
    datagen = ImageDataGenerator(**data_gen_args)

    img_generator = datagen.flow_from_directory(img_path, target_size=IMG_SIZE, class_mode=None,
                                                color_mode='rgb', batch_size=BATCH_SIZE, seed=SEED)
    msk_generator = datagen.flow_from_directory(msk_path, target_size=IMG_SIZE, class_mode=None,
                                                color_mode='grayscale', batch_size=BATCH_SIZE, seed=SEED)
    return zip(img_generator, msk_generator)

train_generator = create_segmentation_generator_train(data_dir_train_image, data_dir_train_mask, BATCH_SIZE_TRAIN)
test_generator = create_segmentation_generator_test(data_dir_test_image, data_dir_test_mask, BATCH_SIZE_TEST)

# Open a strategy scope (`strategy` is a tf.distribute strategy created elsewhere in the script).
with strategy.scope():
    sm.set_framework('tf.keras')

    EPOCH_STEP_TRAIN = NUM_TRAIN // BATCH_SIZE_TRAIN
    EPOCH_STEP_TEST = NUM_TEST // BATCH_SIZE_TEST

    model = sm.Unet(backbone_name='resnet152', classes=1, activation='sigmoid',
                    encoder_weights='imagenet', input_shape=(224, 224, 3),
                    encoder_freeze=True)

    model.compile(
        'Adam',
        loss=sm.losses.DiceLoss(),
        metrics=[sm.metrics.iou_score],
    )

#from segmentation_models.utils import set_trainable
model_checkpoint = ModelCheckpoint('/home/xylexa/Desktop/model_check_16_April_22.h5',
                                   monitor='val_loss', verbose=1, save_best_only=True)

model.fit_generator(generator=train_generator,
                    use_multiprocessing=True,
                    workers=12,
                    steps_per_epoch=EPOCH_STEP_TRAIN,
                    validation_data=test_generator,
                    validation_steps=EPOCH_STEP_TEST,
                    epochs=100, callbacks=[model_checkpoint])

Any help would be appreciated.
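
For anyone debugging the same symptom: the training input here is a plain Python zip() of two DirectoryIterators, which tf.keras treats as an ordinary generator, and the Keras documentation recommends a keras.utils.Sequence for safe use with workers and use_multiprocessing. Below is a minimal sketch (not from the original post; the class name PairedSequence is made up) of wrapping the two iterators in a Sequence so each indexed batch returns an aligned (image, mask) pair:

import tensorflow as tf

class PairedSequence(tf.keras.utils.Sequence):
    """Returns aligned (image, mask) batches from two flow_from_directory iterators."""

    def __init__(self, img_iterator, msk_iterator):
        # For simplicity this assumes both iterators were created with shuffle=False
        # and the same batch_size, so batch i of images matches batch i of masks.
        # With shuffling, a shared index permutation would have to be applied in
        # on_epoch_end() instead of relying on the seed.
        self.img_iterator = img_iterator
        self.msk_iterator = msk_iterator

    def __len__(self):
        # Number of batches per epoch, taken from the image iterator.
        return len(self.img_iterator)

    def __getitem__(self, idx):
        # DirectoryIterator supports indexing, so both sides return batch idx.
        return self.img_iterator[idx], self.msk_iterator[idx]

The generator functions above would then return PairedSequence(img_generator, msk_generator) instead of zip(img_generator, msk_generator), and the resulting object can be passed straight to model.fit.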

attiladoor commented 1 year ago

Have you considered that fit_generator can be very memory hungry and may load in too much training data at once? Maybe try lowering the number of workers, but that's just a guess.
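
As a concrete version of that guess (a sketch only, not a confirmed fix): keeping everything else the same but disabling multiprocessing and using a single worker means far fewer batches are prefetched and held in memory at once, which should at least show whether the prefetch queue is what exhausts the RAM:

model.fit_generator(generator=train_generator,
                    use_multiprocessing=False,  # zip() of iterators is a plain generator; it does not mix well with multiprocessing
                    workers=1,                  # fewer prefetched batches held in memory
                    steps_per_epoch=EPOCH_STEP_TRAIN,
                    validation_data=test_generator,
                    validation_steps=EPOCH_STEP_TEST,
                    epochs=100, callbacks=[model_checkpoint])

Lowering max_queue_size (which defaults to 10) can be tried for the same reason.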