qubvel / segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
MIT License

Training convergence #350

Open: asmagen opened this issue 4 years ago

asmagen commented 4 years ago

My model appears to converge poorly: the validation loss barely decreases, while the validation accuracy reaches 0.75-0.8 after just a couple of epochs and then bounces around that range without improving.

[Screenshot: 2020-06-05, 10:34 PM]

[Image]

Is that reasonable? If not, what could explain it? Could the patches be failing to capture the segmentation classes appropriately, or have too low a resolution (I'm using 1024-pixel patches, which are then scaled down to 320-pixel patches, the input size required in the multi-class segmentation tutorial)? Or is the architecture or loss poorly defined?
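To get a feel for how much detail the 1024 to 320 resize removes, here is a quick arithmetic sketch (the helper name is hypothetical). Each side shrinks by a factor of 320/1024 = 0.3125, so every input pixel of the model summarizes roughly a 3.2 x 3.2 block of original pixels:

```python
def downscaled_size(object_px: float, patch_px: int = 1024, input_px: int = 320) -> float:
    """Approximate size (in pixels) an object keeps after a patch_px x patch_px
    patch is resized to input_px x input_px (hypothetical helper)."""
    scale = input_px / patch_px  # 320 / 1024 = 0.3125
    return object_px * scale

# A structure 32 px across in the 1024-px patch spans about 10 px after resizing.
print(downscaled_size(32))  # 10.0
```

Small structures such as immune aggregates shrink by the same factor, so it is worth checking how many pixels they still cover at 320 px.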

JordanMakesMaps commented 4 years ago

Are you reducing the learning rate after each epoch in which the validation loss does not decrease?

asmagen commented 4 years ago

I think it does, via the following:

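```python
import keras

callbacks = [
    # save weights whenever val_loss (the default monitor) reaches a new minimum
    keras.callbacks.ModelCheckpoint('./best_model.h5', save_weights_only=True, save_best_only=True, mode='min'),
    # all defaults: monitor='val_loss', factor=0.1, patience=10
    keras.callbacks.ReduceLROnPlateau(),
]
```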

Is that correct? If not, how can I do that?

JordanMakesMaps commented 4 years ago

By default, ReduceLROnPlateau waits 10 epochs without improvement (patience=10) before reducing the learning rate, and then multiplies it by 0.1. I would change the default parameters and also set verbose=1 so you can see when the learning rate is reduced and what it is reduced to. This might help.
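The plateau rule can be sketched in plain Python to see when the learning rate actually drops. This is a simplified re-implementation for illustration, not the Keras source: it assumes mode='min', ignores `cooldown`, and the function name is hypothetical. The defaults mirror the Keras ones (factor=0.1, patience=10, min_delta=1e-4, min_lr=0).

```python
def simulate_plateau(val_losses, lr=1e-3, factor=0.1, patience=10,
                     min_lr=0.0, min_delta=1e-4):
    """Return the learning rate in effect after each epoch, following the core
    logic of ReduceLROnPlateau (mode='min', cooldown=0)."""
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best - min_delta:  # counts as an improvement
            best = loss
            wait = 0
        else:
            wait += 1
            # reduce only once `patience` non-improving epochs have accumulated
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history

# Two improving epochs, then a flat validation loss for 12 epochs.
lrs = simulate_plateau([1.0, 0.9] + [0.9] * 12)
print(lrs[10], lrs[11])
```

With the defaults, the first reduction only happens after 10 flat epochs, so in a short run the callback may barely fire.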

What percentage of the data goes to training vs. validation? Also, are you using augmentation?

asmagen commented 4 years ago

Are there any parameters you recommend starting with or parameter ranges to experiment with?

I'm currently using the following, and it doesn't converge well:

```python
BATCH_SIZE = 5
LR = 1e-3
EPOCHS = 100

keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-7, verbose=1)
```
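With these settings the learning rate can only take a handful of values. A small sketch (hypothetical helper) that lists them, assuming the Keras rule new_lr = max(lr * factor, min_lr):

```python
def lr_ladder(initial_lr=1e-3, factor=0.2, min_lr=1e-7):
    """Every learning rate ReduceLROnPlateau can step through, from
    initial_lr down to min_lr (each step multiplies by factor)."""
    rates = [initial_lr]
    while rates[-1] > min_lr:
        rates.append(max(rates[-1] * factor, min_lr))
    return rates

for lr in lr_ladder():
    print(f"{lr:.1e}")
# 1.0e-03, 2.0e-04, 4.0e-05, 8.0e-06, 1.6e-06, 3.2e-07, 1.0e-07 (one per line)
```

That is 6 reductions; with patience=5, reaching the 1e-7 floor takes at least 6 x 5 = 30 epochs without val_loss improvement.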

Thanks

JordanMakesMaps commented 4 years ago

The OOM error from the other post is likely caused by your batch size being 5 instead of 1. Also, there is some merit in using batch sizes that are powers of 2.
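Activation memory grows linearly with batch size, which is why lowering it can avoid an OOM. A rough sketch (hypothetical helper, float32 assumed) for a single activation tensor; the 512 x 512 x 64 shape is an illustrative assumption, roughly a 64-channel map at half the resolution of a 1024 x 1024 input, and a real U-Net keeps many such tensors alive for backpropagation:

```python
def feature_map_mib(batch, height, width, channels, bytes_per_value=4):
    """Memory in MiB for one activation tensor of shape
    (batch, height, width, channels), float32 by default."""
    return batch * height * width * channels * bytes_per_value / 2**20

print(feature_map_mib(1, 512, 512, 64))  # 64.0
print(feature_map_mib(5, 512, 512, 64))  # 320.0
```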

How many images do you have to use?

asmagen commented 4 years ago

Hi @JordanMakesMaps and others,

Thanks for your input. I'll try to consolidate the multiple discussions into one. I'm starting to understand that the hyperparameter space (by which I mean all of the architecture, model, batch, and learning parameters) is huge, and I don't have a clear idea of how to narrow it down to a specific model without trying too many options, so I'm hoping the community's knowledge can help me optimize it. I marked my questions in bold since the post is pretty long.
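One common way to explore a large hyperparameter space without a full grid is random search, sampling the learning rate log-uniformly. A minimal sketch; the parameter names and ranges are assumptions for illustration, not recommendations from this thread:

```python
import random

def sample_config(rng: random.Random) -> dict:
    """Draw one hypothetical training configuration for random search."""
    return {
        "lr": 10 ** rng.uniform(-5, -2),          # log-uniform in [1e-5, 1e-2]
        "batch_size": rng.choice([2, 4, 8, 16]),  # powers of 2
        "plateau_patience": rng.choice([3, 5, 10]),
        "plateau_factor": rng.choice([0.1, 0.2, 0.5]),
    }

rng = random.Random(0)  # fixed seed so the draws are reproducible
for cfg in (sample_config(rng) for _ in range(3)):
    print(cfg)
```

Each sampled configuration would then be trained for a few epochs and compared on validation metrics.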

I'm trying to train a U-Net to segment 7 categories in histopathology slides: some are large general regions (tumor, normal, slide background) and some are smaller structures or artifacts (immune aggregates, bile ducts, shadows or tissue folds). In this small example patch, the stroma is shown in green, the background (a hole) in purple, and, most importantly, the immune aggregates in yellow. These are rare and my utmost priority to segment accurately, although the other categories should be reasonably accurate as well. The rest of the patch is normal tissue, which is very abundant in this slide. This is not a representative area, because the aggregates here are much more abundant than usual.

[Screenshot: example patch with stroma (green), background/hole (purple), and immune aggregates (yellow)]

I have semi-dense annotations for 10 whole slides at 5x resolution (about 5k x 5k pixels), for example (just a screenshot, so very low resolution):

[Screenshot: example annotated whole slide, low resolution]

From these annotations I create a categorical mask with the following categories: ['Normal','Tumor','Stroma','Bile Ducts','Immune Aggregate','Tissue Fold','Background','Unannotated'], where 'Unannotated' captures all of the pixels that didn't map to any of the other categories. Unfortunately, since the slides are very large, we can't annotate every region.
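A minimal NumPy sketch of building such a mask, assuming the annotations are first rasterized into an integer label map where 0 to 6 are the annotated classes in the order above and any other value (for example 255) means unannotated. The function name and the encoding are assumptions:

```python
import numpy as np

def to_one_hot(label_map: np.ndarray, n_annotated: int = 7) -> np.ndarray:
    """Integer label map -> one-hot mask with n_annotated + 1 channels.

    Values 0..n_annotated-1 are annotated classes; any other value is
    treated as unannotated and lands in the extra last channel.
    """
    h, w = label_map.shape
    one_hot = np.zeros((h, w, n_annotated + 1), dtype=np.float32)
    for c in range(n_annotated):
        one_hot[..., c] = label_map == c
    # pixels covered by no annotated class go to the 'Unannotated' channel
    one_hot[..., -1] = 1.0 - one_hot[..., :-1].sum(axis=-1)
    return one_hot

label_map = np.array([[0, 6], [255, 4]], dtype=np.uint8)
mask = to_one_hot(label_map)
print(mask.shape)           # (2, 2, 8)
print(mask[1, 0].argmax())  # 7, i.e. 'Unannotated'
```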

I tile the slides and masks into 512 x 512 patches with a 256-pixel overlap. I chose 512 because it captures smaller structures, such as the immune aggregates, in their entirety. Would it be better to capture smaller or larger contexts if I want to identify these structures accurately?
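The tiling scheme (512-pixel tiles with a 256-pixel overlap, i.e. a stride of 512 - 256 = 256) can be sketched with NumPy slicing. The helper names are hypothetical, and this version drops the partial strip at the right and bottom edges instead of padding it:

```python
import numpy as np

def tile_coords(height: int, width: int, tile: int = 512, stride: int = 256):
    """Top-left (row, col) of every full tile; stride = tile - overlap."""
    rows = range(0, height - tile + 1, stride)
    cols = range(0, width - tile + 1, stride)
    return [(r, c) for r in rows for c in cols]

def extract_tiles(slide: np.ndarray, mask: np.ndarray, tile: int = 512, stride: int = 256):
    """Yield aligned (image_tile, mask_tile) pairs from a slide and its mask."""
    for r, c in tile_coords(slide.shape[0], slide.shape[1], tile, stride):
        yield slide[r:r + tile, c:c + tile], mask[r:r + tile, c:c + tile]

print(len(tile_coords(5000, 5000)))  # 324
```

For a 5000 x 5000 slide this gives 18 x 18 = 324 tiles, and the last 136 pixels along each edge are not covered unless the slide is padded first.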

[Screenshot]

Since the region sizes are highly imbalanced, I select only about 100-500 tiles per slide, favoring tiles that contain the less abundant regions, and I keep only tiles in which less than 25% of the pixels are assigned to the background or unannotated categories. Nonetheless, most of the patches still correspond to the large areas such as tumor/normal.
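A sketch of that kind of tile selection, assuming integer label-map tiles with the class ids in the order of the category list (Background = 6, Unannotated = 7) and treating Bile Ducts, Immune Aggregate, and Tissue Fold (ids 3, 4, 5) as the rare classes. The function names, the id mapping, and the "rarest first" ranking are illustrative assumptions:

```python
import numpy as np

def keep_tile(mask_tile: np.ndarray, ignore_ids=(6, 7), max_ignored=0.25) -> bool:
    """True if less than max_ignored of the pixels are background/unannotated."""
    return np.isin(mask_tile, ignore_ids).mean() < max_ignored

def rare_class_score(mask_tile: np.ndarray, rare_ids=(3, 4, 5)) -> float:
    """Fraction of pixels that belong to the rare classes."""
    return float(np.isin(mask_tile, rare_ids).mean())

def select_tiles(mask_tiles, n_keep=300):
    """Indices of up to n_keep tiles that pass the filter, rarest content first."""
    candidates = [i for i, m in enumerate(mask_tiles) if keep_tile(m)]
    candidates.sort(key=lambda i: rare_class_score(mask_tiles[i]), reverse=True)
    return candidates[:n_keep]
```

Ranking by rare-class fraction is only one option; sampling tiles with probability proportional to that score would keep some variety among the common classes.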

[Screenshot]

I use the patches from 6 whole slides for training: 60% are randomly assigned to the training set and the rest to the validation set. All patches from the other slides (from other patients, as a strong independent validation) are assigned to the independent test set.
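The split can be sketched as follows, assuming tiles are grouped per slide in a dict; the function name, the seed, and the dict layout are hypothetical:

```python
import random

def split_tiles(tiles_by_slide: dict, train_slides, val_frac=0.4, seed=0):
    """Pool tiles from the training slides, split them randomly into
    train/validation, and send every tile from the other slides to test."""
    rng = random.Random(seed)  # fixed seed for a reproducible split
    train_set = set(train_slides)
    pooled = [t for s in train_slides for t in tiles_by_slide[s]]
    rng.shuffle(pooled)
    n_val = int(round(len(pooled) * val_frac))
    val, train = pooled[:n_val], pooled[n_val:]
    test = [t for s, ts in tiles_by_slide.items() if s not in train_set for t in ts]
    return train, val, test
```

For example, with 10 slides of 10 tiles each and 6 training slides, this yields 36 training, 24 validation, and 40 test tiles.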

I apply the following augmentations:

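```python
import albumentations as A

def round_clip_0_1(x, **kwargs):
    # helper referenced by the original snippet (assumed definition):
    # round mask values back to {0, 1} after the transforms
    return x.round().clip(0, 1)

def get_training_augmentation():
    # NOTE: the IAA* transforms, RandomBrightness and RandomContrast come from the
    # albumentations API of that time and were deprecated in later releases.
    train_transform = [
        A.HorizontalFlip(p=0.5),
        # rotate_limit is in degrees, so 0.2 allows almost no rotation
        A.ShiftScaleRotate(scale_limit=0.3, rotate_limit=0.2, shift_limit=0.1, p=1, border_mode=0),
        A.IAAAdditiveGaussianNoise(p=0.2),
        A.IAAPerspective(p=0.5),
        A.OneOf(
            [
                A.CLAHE(p=1),
                A.RandomBrightness(p=1),
                A.RandomGamma(p=1),
            ],
            p=0.9,
        ),
        A.OneOf(
            [
                A.IAASharpen(p=1),
                A.Blur(blur_limit=3, p=1),
                A.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),
        A.OneOf(
            [
                A.RandomContrast(limit=0.1, p=0.5),
                A.HueSaturationValue(p=0.5),
            ],
            p=0.9,
        ),
        A.Lambda(mask=round_clip_0_1),
    ]
    return A.Compose(train_transform)
```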
[Screenshot]

And then comes the tricky model definition. I started with:

Is TPU utilization good enough with a batch size of 8, or is a TPU meant to be used with a significantly larger batch size? Do you recommend a different loss function for the task at hand? And are the class weights I'm defining reasonable?
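One common heuristic for choosing class weights is median-frequency balancing, which weights each class by the median class frequency divided by its own frequency. This is a swapped-in technique for comparison, not something used in the post, and the pixel counts below are made up:

```python
import numpy as np

def median_frequency_weights(pixel_counts):
    """Median-frequency balancing: w_c = median(freq) / freq_c.

    freq_c is class c's share of all counted pixels; assumes every class
    has at least one pixel (a zero count would divide by zero).
    """
    counts = np.asarray(pixel_counts, dtype=np.float64)
    freq = counts / counts.sum()
    return np.median(freq) / freq

# made-up pixel counts for the 7 annotated classes, in CLASSES order
counts = [400, 300, 150, 20, 10, 20, 100]
w = median_frequency_weights(counts)
# append 0.0 so the extra 'Unannotated' channel stays ignored, as in the post
weights = np.round(np.append(w, 0.0), 2)
print(weights)
```

The resulting ratios can be extreme (here 0.25 to 10), so people often clip them or take a square root before use.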

So overall, for clarity, I have:

```python
import numpy as np
import keras
import segmentation_models as sm

BACKBONE = 'resnet34'
BATCH_SIZE = 8
CLASSES = ['Normal','Tumor','Stroma','Bile Ducts','Lymphoid Aggregate','Tissue Fold','Background']
LR = 1e-2
EPOCHS = 50

preprocess_input = sm.get_preprocessing(BACKBONE)

# define network parameters
# binary vs. multiclass; multiclass adds one extra channel for pixels in none of CLASSES ('Unannotated')
n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)
activation = 'sigmoid' if n_classes == 1 else 'softmax'

# create model
model = sm.Unet(BACKBONE, classes=n_classes, activation=activation, encoder_weights='imagenet', input_shape=(None, None, 3))

# one weight per output channel (7 classes + the extra channel);
# 0 removes the extra channel from the Dice term (the focal term is unweighted)
weights = [0.5, 0.5, 1, 1, 1, 1, 0.5, 0]

# define optimizer
optim = keras.optimizers.Adam(LR)

dice_loss = sm.losses.DiceLoss(class_weights=np.array(weights))
focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (0.5 * focal_loss)

metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

# compile keras model with defined optimizer, loss and metrics
model.compile(optim, total_loss, metrics)

# Dataset, Dataloder, the augmentation/preprocessing helpers and the
# x_/y_ directory paths are defined elsewhere (not shown here)

# Dataset for train images
train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    classes=CLASSES,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocess_input),
)

# Dataset for validation images
valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocess_input),
)

train_dataloader = Dataloder(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = Dataloder(valid_dataset, batch_size=1, shuffle=False)

# define callbacks for learning rate scheduling and best checkpoints saving
callbacks = [
    keras.callbacks.ModelCheckpoint('./best_model.h5', save_weights_only=True, save_best_only=True, mode='min'),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-7, verbose=True),
]

# print model layers
model.summary()
```
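To see what `class_weights` does to the Dice part of `total_loss`, here is a NumPy sketch of a class-weighted soft Dice loss. It assumes, based on my reading of segmentation_models' functional code, that per-class Dice scores are multiplied by the weights and then averaged with a plain mean (not divided by the sum of the weights); please verify this against the installed version. Under that assumption, these weights put a floor under the loss:

```python
import numpy as np

def weighted_dice_loss(y_true, y_pred, class_weights, smooth=1e-5):
    """Soft Dice loss with channels last. Per-class Dice scores are multiplied
    by class_weights and then averaged with a plain mean (assumed behavior)."""
    axes = tuple(range(y_true.ndim - 1))  # sum over batch and spatial axes
    tp = np.sum(y_true * y_pred, axis=axes)
    fp = np.sum(y_pred, axis=axes) - tp
    fn = np.sum(y_true, axis=axes) - tp
    dice = (2 * tp + smooth) / (2 * tp + fn + fp + smooth)
    return 1.0 - np.mean(dice * np.asarray(class_weights))

weights = [0.5, 0.5, 1, 1, 1, 1, 0.5, 0]
# random one-hot masks of shape (2, 16, 16, 8); predict them perfectly
y = np.eye(8)[np.random.default_rng(0).integers(0, 8, size=(2, 16, 16))]
print(round(weighted_dice_loss(y, y, weights), 4))  # 0.3125
```

Even a perfect prediction leaves the Dice term at 1 - mean(weights) = 1 - 5.5/8 = 0.3125, so a validation loss that levels off well above zero does not by itself mean training has stalled.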

This results in a very large architecture:

Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
data (InputLayer)               (None, None, None, 3 0                                            
__________________________________________________________________________________________________
bn_data (BatchNormalization)    (None, None, None, 3 9           data[0][0]                       
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, None, None, 3 0           bn_data[0][0]                    
__________________________________________________________________________________________________
conv0 (Conv2D)                  (None, None, None, 6 9408        zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
bn0 (BatchNormalization)        (None, None, None, 6 256         conv0[0][0]                      
__________________________________________________________________________________________________
relu0 (Activation)              (None, None, None, 6 0           bn0[0][0]                        
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, None, None, 6 0           relu0[0][0]                      
__________________________________________________________________________________________________
pooling0 (MaxPooling2D)         (None, None, None, 6 0           zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
stage1_unit1_bn1 (BatchNormaliz (None, None, None, 6 256         pooling0[0][0]                   
__________________________________________________________________________________________________
stage1_unit1_relu1 (Activation) (None, None, None, 6 0           stage1_unit1_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, None, None, 6 0           stage1_unit1_relu1[0][0]         
__________________________________________________________________________________________________
stage1_unit1_conv1 (Conv2D)     (None, None, None, 6 36864       zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
stage1_unit1_bn2 (BatchNormaliz (None, None, None, 6 256         stage1_unit1_conv1[0][0]         
__________________________________________________________________________________________________
stage1_unit1_relu2 (Activation) (None, None, None, 6 0           stage1_unit1_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, None, None, 6 0           stage1_unit1_relu2[0][0]         
__________________________________________________________________________________________________
stage1_unit1_conv2 (Conv2D)     (None, None, None, 6 36864       zero_padding2d_4[0][0]           
__________________________________________________________________________________________________
stage1_unit1_sc (Conv2D)        (None, None, None, 6 4096        stage1_unit1_relu1[0][0]         
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, None, 6 0           stage1_unit1_conv2[0][0]         
                                                                 stage1_unit1_sc[0][0]            
__________________________________________________________________________________________________
stage1_unit2_bn1 (BatchNormaliz (None, None, None, 6 256         add_1[0][0]                      
__________________________________________________________________________________________________
stage1_unit2_relu1 (Activation) (None, None, None, 6 0           stage1_unit2_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_5 (ZeroPadding2D (None, None, None, 6 0           stage1_unit2_relu1[0][0]         
__________________________________________________________________________________________________
stage1_unit2_conv1 (Conv2D)     (None, None, None, 6 36864       zero_padding2d_5[0][0]           
__________________________________________________________________________________________________
stage1_unit2_bn2 (BatchNormaliz (None, None, None, 6 256         stage1_unit2_conv1[0][0]         
__________________________________________________________________________________________________
stage1_unit2_relu2 (Activation) (None, None, None, 6 0           stage1_unit2_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_6 (ZeroPadding2D (None, None, None, 6 0           stage1_unit2_relu2[0][0]         
__________________________________________________________________________________________________
stage1_unit2_conv2 (Conv2D)     (None, None, None, 6 36864       zero_padding2d_6[0][0]           
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, None, 6 0           stage1_unit2_conv2[0][0]         
                                                                 add_1[0][0]                      
__________________________________________________________________________________________________
stage1_unit3_bn1 (BatchNormaliz (None, None, None, 6 256         add_2[0][0]                      
__________________________________________________________________________________________________
stage1_unit3_relu1 (Activation) (None, None, None, 6 0           stage1_unit3_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_7 (ZeroPadding2D (None, None, None, 6 0           stage1_unit3_relu1[0][0]         
__________________________________________________________________________________________________
stage1_unit3_conv1 (Conv2D)     (None, None, None, 6 36864       zero_padding2d_7[0][0]           
__________________________________________________________________________________________________
stage1_unit3_bn2 (BatchNormaliz (None, None, None, 6 256         stage1_unit3_conv1[0][0]         
__________________________________________________________________________________________________
stage1_unit3_relu2 (Activation) (None, None, None, 6 0           stage1_unit3_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_8 (ZeroPadding2D (None, None, None, 6 0           stage1_unit3_relu2[0][0]         
__________________________________________________________________________________________________
stage1_unit3_conv2 (Conv2D)     (None, None, None, 6 36864       zero_padding2d_8[0][0]           
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, None, 6 0           stage1_unit3_conv2[0][0]         
                                                                 add_2[0][0]                      
__________________________________________________________________________________________________
stage2_unit1_bn1 (BatchNormaliz (None, None, None, 6 256         add_3[0][0]                      
__________________________________________________________________________________________________
stage2_unit1_relu1 (Activation) (None, None, None, 6 0           stage2_unit1_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_9 (ZeroPadding2D (None, None, None, 6 0           stage2_unit1_relu1[0][0]         
__________________________________________________________________________________________________
stage2_unit1_conv1 (Conv2D)     (None, None, None, 1 73728       zero_padding2d_9[0][0]           
__________________________________________________________________________________________________
stage2_unit1_bn2 (BatchNormaliz (None, None, None, 1 512         stage2_unit1_conv1[0][0]         
__________________________________________________________________________________________________
stage2_unit1_relu2 (Activation) (None, None, None, 1 0           stage2_unit1_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_10 (ZeroPadding2 (None, None, None, 1 0           stage2_unit1_relu2[0][0]         
__________________________________________________________________________________________________
stage2_unit1_conv2 (Conv2D)     (None, None, None, 1 147456      zero_padding2d_10[0][0]          
__________________________________________________________________________________________________
stage2_unit1_sc (Conv2D)        (None, None, None, 1 8192        stage2_unit1_relu1[0][0]         
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, None, 1 0           stage2_unit1_conv2[0][0]         
                                                                 stage2_unit1_sc[0][0]            
__________________________________________________________________________________________________
stage2_unit2_bn1 (BatchNormaliz (None, None, None, 1 512         add_4[0][0]                      
__________________________________________________________________________________________________
stage2_unit2_relu1 (Activation) (None, None, None, 1 0           stage2_unit2_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_11 (ZeroPadding2 (None, None, None, 1 0           stage2_unit2_relu1[0][0]         
__________________________________________________________________________________________________
stage2_unit2_conv1 (Conv2D)     (None, None, None, 1 147456      zero_padding2d_11[0][0]          
__________________________________________________________________________________________________
stage2_unit2_bn2 (BatchNormaliz (None, None, None, 1 512         stage2_unit2_conv1[0][0]         
__________________________________________________________________________________________________
stage2_unit2_relu2 (Activation) (None, None, None, 1 0           stage2_unit2_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_12 (ZeroPadding2 (None, None, None, 1 0           stage2_unit2_relu2[0][0]         
__________________________________________________________________________________________________
stage2_unit2_conv2 (Conv2D)     (None, None, None, 1 147456      zero_padding2d_12[0][0]          
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, None, 1 0           stage2_unit2_conv2[0][0]         
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
stage2_unit3_bn1 (BatchNormaliz (None, None, None, 1 512         add_5[0][0]                      
__________________________________________________________________________________________________
stage2_unit3_relu1 (Activation) (None, None, None, 1 0           stage2_unit3_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_13 (ZeroPadding2 (None, None, None, 1 0           stage2_unit3_relu1[0][0]         
__________________________________________________________________________________________________
stage2_unit3_conv1 (Conv2D)     (None, None, None, 1 147456      zero_padding2d_13[0][0]          
__________________________________________________________________________________________________
stage2_unit3_bn2 (BatchNormaliz (None, None, None, 1 512         stage2_unit3_conv1[0][0]         
__________________________________________________________________________________________________
stage2_unit3_relu2 (Activation) (None, None, None, 1 0           stage2_unit3_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_14 (ZeroPadding2 (None, None, None, 1 0           stage2_unit3_relu2[0][0]         
__________________________________________________________________________________________________
stage2_unit3_conv2 (Conv2D)     (None, None, None, 1 147456      zero_padding2d_14[0][0]          
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, None, 1 0           stage2_unit3_conv2[0][0]         
                                                                 add_5[0][0]                      
__________________________________________________________________________________________________
stage2_unit4_bn1 (BatchNormaliz (None, None, None, 1 512         add_6[0][0]                      
__________________________________________________________________________________________________
stage2_unit4_relu1 (Activation) (None, None, None, 1 0           stage2_unit4_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_15 (ZeroPadding2 (None, None, None, 1 0           stage2_unit4_relu1[0][0]         
__________________________________________________________________________________________________
stage2_unit4_conv1 (Conv2D)     (None, None, None, 1 147456      zero_padding2d_15[0][0]          
__________________________________________________________________________________________________
stage2_unit4_bn2 (BatchNormaliz (None, None, None, 1 512         stage2_unit4_conv1[0][0]         
__________________________________________________________________________________________________
stage2_unit4_relu2 (Activation) (None, None, None, 1 0           stage2_unit4_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_16 (ZeroPadding2 (None, None, None, 1 0           stage2_unit4_relu2[0][0]         
__________________________________________________________________________________________________
stage2_unit4_conv2 (Conv2D)     (None, None, None, 1 147456      zero_padding2d_16[0][0]          
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, None, 1 0           stage2_unit4_conv2[0][0]         
                                                                 add_6[0][0]                      
__________________________________________________________________________________________________
stage3_unit1_bn1 (BatchNormaliz (None, None, None, 1 512         add_7[0][0]                      
__________________________________________________________________________________________________
stage3_unit1_relu1 (Activation) (None, None, None, 1 0           stage3_unit1_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_17 (ZeroPadding2 (None, None, None, 1 0           stage3_unit1_relu1[0][0]         
__________________________________________________________________________________________________
stage3_unit1_conv1 (Conv2D)     (None, None, None, 2 294912      zero_padding2d_17[0][0]          
__________________________________________________________________________________________________
stage3_unit1_bn2 (BatchNormaliz (None, None, None, 2 1024        stage3_unit1_conv1[0][0]         
__________________________________________________________________________________________________
stage3_unit1_relu2 (Activation) (None, None, None, 2 0           stage3_unit1_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_18 (ZeroPadding2 (None, None, None, 2 0           stage3_unit1_relu2[0][0]         
__________________________________________________________________________________________________
stage3_unit1_conv2 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_18[0][0]          
__________________________________________________________________________________________________
stage3_unit1_sc (Conv2D)        (None, None, None, 2 32768       stage3_unit1_relu1[0][0]         
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, None, 2 0           stage3_unit1_conv2[0][0]         
                                                                 stage3_unit1_sc[0][0]            
__________________________________________________________________________________________________
stage3_unit2_bn1 (BatchNormaliz (None, None, None, 2 1024        add_8[0][0]                      
__________________________________________________________________________________________________
stage3_unit2_relu1 (Activation) (None, None, None, 2 0           stage3_unit2_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_19 (ZeroPadding2 (None, None, None, 2 0           stage3_unit2_relu1[0][0]         
__________________________________________________________________________________________________
stage3_unit2_conv1 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_19[0][0]          
__________________________________________________________________________________________________
stage3_unit2_bn2 (BatchNormaliz (None, None, None, 2 1024        stage3_unit2_conv1[0][0]         
__________________________________________________________________________________________________
stage3_unit2_relu2 (Activation) (None, None, None, 2 0           stage3_unit2_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_20 (ZeroPadding2 (None, None, None, 2 0           stage3_unit2_relu2[0][0]         
__________________________________________________________________________________________________
stage3_unit2_conv2 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_20[0][0]          
__________________________________________________________________________________________________
add_9 (Add)                     (None, None, None, 2 0           stage3_unit2_conv2[0][0]         
                                                                 add_8[0][0]                      
__________________________________________________________________________________________________
stage3_unit3_bn1 (BatchNormaliz (None, None, None, 2 1024        add_9[0][0]                      
__________________________________________________________________________________________________
stage3_unit3_relu1 (Activation) (None, None, None, 2 0           stage3_unit3_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_21 (ZeroPadding2 (None, None, None, 2 0           stage3_unit3_relu1[0][0]         
__________________________________________________________________________________________________
stage3_unit3_conv1 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_21[0][0]          
__________________________________________________________________________________________________
stage3_unit3_bn2 (BatchNormaliz (None, None, None, 2 1024        stage3_unit3_conv1[0][0]         
__________________________________________________________________________________________________
stage3_unit3_relu2 (Activation) (None, None, None, 2 0           stage3_unit3_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_22 (ZeroPadding2 (None, None, None, 2 0           stage3_unit3_relu2[0][0]         
__________________________________________________________________________________________________
stage3_unit3_conv2 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_22[0][0]          
__________________________________________________________________________________________________
add_10 (Add)                    (None, None, None, 2 0           stage3_unit3_conv2[0][0]         
                                                                 add_9[0][0]                      
__________________________________________________________________________________________________
stage3_unit4_bn1 (BatchNormaliz (None, None, None, 2 1024        add_10[0][0]                     
__________________________________________________________________________________________________
stage3_unit4_relu1 (Activation) (None, None, None, 2 0           stage3_unit4_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_23 (ZeroPadding2 (None, None, None, 2 0           stage3_unit4_relu1[0][0]         
__________________________________________________________________________________________________
stage3_unit4_conv1 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_23[0][0]          
__________________________________________________________________________________________________
stage3_unit4_bn2 (BatchNormaliz (None, None, None, 2 1024        stage3_unit4_conv1[0][0]         
__________________________________________________________________________________________________
stage3_unit4_relu2 (Activation) (None, None, None, 2 0           stage3_unit4_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_24 (ZeroPadding2 (None, None, None, 2 0           stage3_unit4_relu2[0][0]         
__________________________________________________________________________________________________
stage3_unit4_conv2 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_24[0][0]          
__________________________________________________________________________________________________
add_11 (Add)                    (None, None, None, 2 0           stage3_unit4_conv2[0][0]         
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
stage3_unit5_bn1 (BatchNormaliz (None, None, None, 2 1024        add_11[0][0]                     
__________________________________________________________________________________________________
stage3_unit5_relu1 (Activation) (None, None, None, 2 0           stage3_unit5_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_25 (ZeroPadding2 (None, None, None, 2 0           stage3_unit5_relu1[0][0]         
__________________________________________________________________________________________________
stage3_unit5_conv1 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_25[0][0]          
__________________________________________________________________________________________________
stage3_unit5_bn2 (BatchNormaliz (None, None, None, 2 1024        stage3_unit5_conv1[0][0]         
__________________________________________________________________________________________________
stage3_unit5_relu2 (Activation) (None, None, None, 2 0           stage3_unit5_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_26 (ZeroPadding2 (None, None, None, 2 0           stage3_unit5_relu2[0][0]         
__________________________________________________________________________________________________
stage3_unit5_conv2 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_26[0][0]          
__________________________________________________________________________________________________
add_12 (Add)                    (None, None, None, 2 0           stage3_unit5_conv2[0][0]         
                                                                 add_11[0][0]                     
__________________________________________________________________________________________________
stage3_unit6_bn1 (BatchNormaliz (None, None, None, 2 1024        add_12[0][0]                     
__________________________________________________________________________________________________
stage3_unit6_relu1 (Activation) (None, None, None, 2 0           stage3_unit6_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_27 (ZeroPadding2 (None, None, None, 2 0           stage3_unit6_relu1[0][0]         
__________________________________________________________________________________________________
stage3_unit6_conv1 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_27[0][0]          
__________________________________________________________________________________________________
stage3_unit6_bn2 (BatchNormaliz (None, None, None, 2 1024        stage3_unit6_conv1[0][0]         
__________________________________________________________________________________________________
stage3_unit6_relu2 (Activation) (None, None, None, 2 0           stage3_unit6_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_28 (ZeroPadding2 (None, None, None, 2 0           stage3_unit6_relu2[0][0]         
__________________________________________________________________________________________________
stage3_unit6_conv2 (Conv2D)     (None, None, None, 2 589824      zero_padding2d_28[0][0]          
__________________________________________________________________________________________________
add_13 (Add)                    (None, None, None, 2 0           stage3_unit6_conv2[0][0]         
                                                                 add_12[0][0]                     
__________________________________________________________________________________________________
stage4_unit1_bn1 (BatchNormaliz (None, None, None, 2 1024        add_13[0][0]                     
__________________________________________________________________________________________________
stage4_unit1_relu1 (Activation) (None, None, None, 2 0           stage4_unit1_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_29 (ZeroPadding2 (None, None, None, 2 0           stage4_unit1_relu1[0][0]         
__________________________________________________________________________________________________
stage4_unit1_conv1 (Conv2D)     (None, None, None, 5 1179648     zero_padding2d_29[0][0]          
__________________________________________________________________________________________________
stage4_unit1_bn2 (BatchNormaliz (None, None, None, 5 2048        stage4_unit1_conv1[0][0]         
__________________________________________________________________________________________________
stage4_unit1_relu2 (Activation) (None, None, None, 5 0           stage4_unit1_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_30 (ZeroPadding2 (None, None, None, 5 0           stage4_unit1_relu2[0][0]         
__________________________________________________________________________________________________
stage4_unit1_conv2 (Conv2D)     (None, None, None, 5 2359296     zero_padding2d_30[0][0]          
__________________________________________________________________________________________________
stage4_unit1_sc (Conv2D)        (None, None, None, 5 131072      stage4_unit1_relu1[0][0]         
__________________________________________________________________________________________________
add_14 (Add)                    (None, None, None, 5 0           stage4_unit1_conv2[0][0]         
                                                                 stage4_unit1_sc[0][0]            
__________________________________________________________________________________________________
stage4_unit2_bn1 (BatchNormaliz (None, None, None, 5 2048        add_14[0][0]                     
__________________________________________________________________________________________________
stage4_unit2_relu1 (Activation) (None, None, None, 5 0           stage4_unit2_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_31 (ZeroPadding2 (None, None, None, 5 0           stage4_unit2_relu1[0][0]         
__________________________________________________________________________________________________
stage4_unit2_conv1 (Conv2D)     (None, None, None, 5 2359296     zero_padding2d_31[0][0]          
__________________________________________________________________________________________________
stage4_unit2_bn2 (BatchNormaliz (None, None, None, 5 2048        stage4_unit2_conv1[0][0]         
__________________________________________________________________________________________________
stage4_unit2_relu2 (Activation) (None, None, None, 5 0           stage4_unit2_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_32 (ZeroPadding2 (None, None, None, 5 0           stage4_unit2_relu2[0][0]         
__________________________________________________________________________________________________
stage4_unit2_conv2 (Conv2D)     (None, None, None, 5 2359296     zero_padding2d_32[0][0]          
__________________________________________________________________________________________________
add_15 (Add)                    (None, None, None, 5 0           stage4_unit2_conv2[0][0]         
                                                                 add_14[0][0]                     
__________________________________________________________________________________________________
stage4_unit3_bn1 (BatchNormaliz (None, None, None, 5 2048        add_15[0][0]                     
__________________________________________________________________________________________________
stage4_unit3_relu1 (Activation) (None, None, None, 5 0           stage4_unit3_bn1[0][0]           
__________________________________________________________________________________________________
zero_padding2d_33 (ZeroPadding2 (None, None, None, 5 0           stage4_unit3_relu1[0][0]         
__________________________________________________________________________________________________
stage4_unit3_conv1 (Conv2D)     (None, None, None, 5 2359296     zero_padding2d_33[0][0]          
__________________________________________________________________________________________________
stage4_unit3_bn2 (BatchNormaliz (None, None, None, 5 2048        stage4_unit3_conv1[0][0]         
__________________________________________________________________________________________________
stage4_unit3_relu2 (Activation) (None, None, None, 5 0           stage4_unit3_bn2[0][0]           
__________________________________________________________________________________________________
zero_padding2d_34 (ZeroPadding2 (None, None, None, 5 0           stage4_unit3_relu2[0][0]         
__________________________________________________________________________________________________
stage4_unit3_conv2 (Conv2D)     (None, None, None, 5 2359296     zero_padding2d_34[0][0]          
__________________________________________________________________________________________________
add_16 (Add)                    (None, None, None, 5 0           stage4_unit3_conv2[0][0]         
                                                                 add_15[0][0]                     
__________________________________________________________________________________________________
bn1 (BatchNormalization)        (None, None, None, 5 2048        add_16[0][0]                     
__________________________________________________________________________________________________
relu1 (Activation)              (None, None, None, 5 0           bn1[0][0]                        
__________________________________________________________________________________________________
decoder_stage0_upsampling (UpSa (None, None, None, 5 0           relu1[0][0]                      
__________________________________________________________________________________________________
decoder_stage0_concat (Concaten (None, None, None, 7 0           decoder_stage0_upsampling[0][0]  
                                                                 stage4_unit1_relu1[0][0]         
__________________________________________________________________________________________________
decoder_stage0a_conv (Conv2D)   (None, None, None, 2 1769472     decoder_stage0_concat[0][0]      
__________________________________________________________________________________________________
decoder_stage0a_bn (BatchNormal (None, None, None, 2 1024        decoder_stage0a_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage0a_relu (Activatio (None, None, None, 2 0           decoder_stage0a_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage0b_conv (Conv2D)   (None, None, None, 2 589824      decoder_stage0a_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage0b_bn (BatchNormal (None, None, None, 2 1024        decoder_stage0b_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage0b_relu (Activatio (None, None, None, 2 0           decoder_stage0b_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage1_upsampling (UpSa (None, None, None, 2 0           decoder_stage0b_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage1_concat (Concaten (None, None, None, 3 0           decoder_stage1_upsampling[0][0]  
                                                                 stage3_unit1_relu1[0][0]         
__________________________________________________________________________________________________
decoder_stage1a_conv (Conv2D)   (None, None, None, 1 442368      decoder_stage1_concat[0][0]      
__________________________________________________________________________________________________
decoder_stage1a_bn (BatchNormal (None, None, None, 1 512         decoder_stage1a_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage1a_relu (Activatio (None, None, None, 1 0           decoder_stage1a_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage1b_conv (Conv2D)   (None, None, None, 1 147456      decoder_stage1a_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage1b_bn (BatchNormal (None, None, None, 1 512         decoder_stage1b_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage1b_relu (Activatio (None, None, None, 1 0           decoder_stage1b_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage2_upsampling (UpSa (None, None, None, 1 0           decoder_stage1b_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage2_concat (Concaten (None, None, None, 1 0           decoder_stage2_upsampling[0][0]  
                                                                 stage2_unit1_relu1[0][0]         
__________________________________________________________________________________________________
decoder_stage2a_conv (Conv2D)   (None, None, None, 6 110592      decoder_stage2_concat[0][0]      
__________________________________________________________________________________________________
decoder_stage2a_bn (BatchNormal (None, None, None, 6 256         decoder_stage2a_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage2a_relu (Activatio (None, None, None, 6 0           decoder_stage2a_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage2b_conv (Conv2D)   (None, None, None, 6 36864       decoder_stage2a_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage2b_bn (BatchNormal (None, None, None, 6 256         decoder_stage2b_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage2b_relu (Activatio (None, None, None, 6 0           decoder_stage2b_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage3_upsampling (UpSa (None, None, None, 6 0           decoder_stage2b_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage3_concat (Concaten (None, None, None, 1 0           decoder_stage3_upsampling[0][0]  
                                                                 relu0[0][0]                      
__________________________________________________________________________________________________
decoder_stage3a_conv (Conv2D)   (None, None, None, 3 36864       decoder_stage3_concat[0][0]      
__________________________________________________________________________________________________
decoder_stage3a_bn (BatchNormal (None, None, None, 3 128         decoder_stage3a_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage3a_relu (Activatio (None, None, None, 3 0           decoder_stage3a_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage3b_conv (Conv2D)   (None, None, None, 3 9216        decoder_stage3a_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage3b_bn (BatchNormal (None, None, None, 3 128         decoder_stage3b_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage3b_relu (Activatio (None, None, None, 3 0           decoder_stage3b_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage4_upsampling (UpSa (None, None, None, 3 0           decoder_stage3b_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage4a_conv (Conv2D)   (None, None, None, 1 4608        decoder_stage4_upsampling[0][0]  
__________________________________________________________________________________________________
decoder_stage4a_bn (BatchNormal (None, None, None, 1 64          decoder_stage4a_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage4a_relu (Activatio (None, None, None, 1 0           decoder_stage4a_bn[0][0]         
__________________________________________________________________________________________________
decoder_stage4b_conv (Conv2D)   (None, None, None, 1 2304        decoder_stage4a_relu[0][0]       
__________________________________________________________________________________________________
decoder_stage4b_bn (BatchNormal (None, None, None, 1 64          decoder_stage4b_conv[0][0]       
__________________________________________________________________________________________________
decoder_stage4b_relu (Activatio (None, None, None, 1 0           decoder_stage4b_bn[0][0]         
__________________________________________________________________________________________________
final_conv (Conv2D)             (None, None, None, 8 1160        decoder_stage4b_relu[0][0]       
__________________________________________________________________________________________________
softmax (Activation)            (None, None, None, 8 0           final_conv[0][0]                 
==================================================================================================
Total params: 24,457,169
Trainable params: 24,439,819
Non-trainable params: 17,350
__________________________________________________________________________________________________

Does it make sense to have so many layers, or did I do something wrong? And is it reasonable to have such a large number of parameters (24,439,819 trainable)?
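As a rough sanity check, the per-layer counts in the summary follow the usual formulas: a `Conv2D` layer without bias has k·k·C_in·C_out weights (a bias adds C_out), and a `BatchNormalization` layer stores 4·C values (gamma and beta, which are trainable, plus the non-trainable moving mean and variance). The sketch below checks a few rows against those formulas. The channel counts (256 in stage 3, 512 in stage 4, decoder widths 256 and 16, 8 output classes) are inferred from the summary and the usual ResNet-34 layout ([3, 4, 6, 3] basic blocks), so treat them as assumptions; `conv2d_params` and `batchnorm_params` are hypothetical helpers.

```python
def conv2d_params(kernel, c_in, c_out, bias=False):
    """Weights of a square Conv2D kernel, plus one bias per output channel if used."""
    return kernel * kernel * c_in * c_out + (c_out if bias else 0)


def batchnorm_params(channels):
    """gamma and beta (trainable) + moving mean and variance (non-trainable)."""
    return 4 * channels


# (layer name, value printed in the summary, value predicted by the formulas)
checks = [
    ("stage3_unit4_conv1", 589824, conv2d_params(3, 256, 256)),
    ("stage4_unit1_conv2", 2359296, conv2d_params(3, 512, 512)),
    ("stage4_unit1_sc", 131072, conv2d_params(1, 256, 512)),
    # first decoder conv sees 512 upsampled channels + 256 skip-connection channels
    ("decoder_stage0a_conv", 1769472, conv2d_params(3, 512 + 256, 256)),
    ("decoder_stage4b_conv", 2304, conv2d_params(3, 16, 16)),
    # the final 3x3 conv maps 16 channels to 8 classes and keeps its bias
    ("final_conv", 1160, conv2d_params(3, 16, 8, bias=True)),
    ("stage3_unit4_bn2", 1024, batchnorm_params(256)),
    ("stage4_unit1_bn2", 2048, batchnorm_params(512)),
]

for name, printed, predicted in checks:
    status = "ok" if printed == predicted else "MISMATCH"
    print(f"{name:22s} summary={printed:>9,d} formula={predicted:>9,d} {status}")
```

Most of the ~24.4M parameters sit in the 3x3 convolutions of the deepest encoder stage and the first decoder block (each costing roughly 1.2M to 2.4M weights), so the total appears consistent with a U-Net on a ResNet-34 encoder rather than with a misconfiguration; the number of layers comes from the backbone itself.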

I train the model with:

# train model
history = model.fit_generator(
    train_dataloader, 
    steps_per_epoch=len(train_dataloader), 
    epochs=EPOCHS, 
    callbacks=callbacks, 
    validation_data=valid_dataloader, 
    validation_steps=len(valid_dataloader),
)

Each epoch takes a long time (about 12-13 minutes, roughly 755 s), even on the Google Colab TPU. The training loss decreases only marginally, and the validation F1-score stays low, fluctuating between roughly 0.47 and 0.68:

Epoch 1/50
103/103 [==============================] - 766s 7s/step - loss: 0.9286 - iou_score: 0.3623 - f1-score: 0.3974 - val_loss: 1.2105 - val_iou_score: 0.5753 - val_f1-score: 0.5810
Epoch 2/50
103/103 [==============================] - 755s 7s/step - loss: 0.9157 - iou_score: 0.3288 - f1-score: 0.3772 - val_loss: 0.9429 - val_iou_score: 0.4850 - val_f1-score: 0.4964
Epoch 3/50
103/103 [==============================] - 757s 7s/step - loss: 0.9145 - iou_score: 0.3509 - f1-score: 0.3996 - val_loss: 0.9518 - val_iou_score: 0.6084 - val_f1-score: 0.6173
Epoch 4/50
103/103 [==============================] - 756s 7s/step - loss: 0.9070 - iou_score: 0.3628 - f1-score: 0.4145 - val_loss: 1.0308 - val_iou_score: 0.5224 - val_f1-score: 0.5399
Epoch 5/50
103/103 [==============================] - 755s 7s/step - loss: 0.9070 - iou_score: 0.3589 - f1-score: 0.4144 - val_loss: 1.0440 - val_iou_score: 0.6339 - val_f1-score: 0.6431
Epoch 6/50
103/103 [==============================] - 757s 7s/step - loss: 0.9059 - iou_score: 0.3637 - f1-score: 0.4166 - val_loss: 0.9821 - val_iou_score: 0.4707 - val_f1-score: 0.4850
Epoch 7/50
103/103 [==============================] - 757s 7s/step - loss: 0.9052 - iou_score: 0.3452 - f1-score: 0.4023 - val_loss: 0.9419 - val_iou_score: 0.4677 - val_f1-score: 0.4753
Epoch 8/50
103/103 [==============================] - 757s 7s/step - loss: 0.8999 - iou_score: 0.3790 - f1-score: 0.4356 - val_loss: 1.0377 - val_iou_score: 0.6374 - val_f1-score: 0.6427
Epoch 9/50
103/103 [==============================] - 758s 7s/step - loss: 0.9047 - iou_score: 0.3608 - f1-score: 0.4151 - val_loss: 0.9414 - val_iou_score: 0.6363 - val_f1-score: 0.6466
Epoch 10/50
103/103 [==============================] - 758s 7s/step - loss: 0.9003 - iou_score: 0.3732 - f1-score: 0.4312 - val_loss: 1.1248 - val_iou_score: 0.5027 - val_f1-score: 0.5088
Epoch 11/50
103/103 [==============================] - 758s 7s/step - loss: 0.8949 - iou_score: 0.3774 - f1-score: 0.4345 - val_loss: 1.0416 - val_iou_score: 0.5164 - val_f1-score: 0.5217
Epoch 12/50
103/103 [==============================] - 758s 7s/step - loss: 0.8923 - iou_score: 0.3724 - f1-score: 0.4294 - val_loss: 0.9391 - val_iou_score: 0.6727 - val_f1-score: 0.6779
Epoch 13/50
103/103 [==============================] - 758s 7s/step - loss: 0.8873 - iou_score: 0.3840 - f1-score: 0.4432 - val_loss: 0.9393 - val_iou_score: 0.6417 - val_f1-score: 0.6463
Epoch 14/50
103/103 [==============================] - 757s 7s/step - loss: 0.8989 - iou_score: 0.3508 - f1-score: 0.4102 - val_loss: 0.9382 - val_iou_score: 0.6733 - val_f1-score: 0.6786
Epoch 15/50
103/103 [==============================] - 758s 7s/step - loss: 0.8995 - iou_score: 0.3591 - f1-score: 0.4155 - val_loss: 0.9826 - val_iou_score: 0.5870 - val_f1-score: 0.6028
Epoch 16/50
 74/103 [====================>.........] - ETA: 2:54 - loss: 0.8882 - iou_score: 0.3703 - f1-score: 0.4323
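To put numbers on the "marginal" improvement, the per-epoch values from the 15 completed epochs above can be summarized with the standard library. This is only a reading of the log, not a diagnosis: it shows the training loss falling by about 4% while the validation F1-score jumps by up to about 0.17 between consecutive epochs. Swings that large can point to a small or unrepresentative validation set or an unstable optimization setup; the variable names below are illustrative.

```python
from statistics import mean, pstdev

# Values copied from the completed epochs 1-15 of the log above.
train_loss = [0.9286, 0.9157, 0.9145, 0.9070, 0.9070, 0.9059, 0.9052, 0.8999,
              0.9047, 0.9003, 0.8949, 0.8923, 0.8873, 0.8989, 0.8995]
val_f1 = [0.5810, 0.4964, 0.6173, 0.5399, 0.6431, 0.4850, 0.4753, 0.6427,
          0.6466, 0.5088, 0.5217, 0.6779, 0.6463, 0.6786, 0.6028]

# Relative drop of the training loss from epoch 1 to its best epoch.
best_train = min(train_loss)
drop_pct = 100 * (train_loss[0] - best_train) / train_loss[0]

# Largest change in validation F1 between two consecutive epochs.
max_jump = max(abs(b - a) for a, b in zip(val_f1, val_f1[1:]))

print(f"train loss: {train_loss[0]:.4f} -> best {best_train:.4f} ({drop_pct:.1f}% lower)")
print(f"val F1: min {min(val_f1):.4f}, max {max(val_f1):.4f}, "
      f"mean {mean(val_f1):.4f}, std {pstdev(val_f1):.4f}")
print(f"largest epoch-to-epoch change in val F1: {max_jump:.4f}")
```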

Is it reasonable for training to take this long on a GPU or TPU? Does the high loss, on both the training and validation sets, indicate a problem with the input data or with the learning setup? And in general, what strategy do you suggest for tuning these parameters, and what ranges should I experiment with?
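Before tuning ranges, it may be worth checking whether the learning-rate schedule ever fires. The sketch below is a simplified pure-Python re-implementation of the rule Keras's `ReduceLROnPlateau` applies (mode `min`, reduce by `factor` after `patience` epochs without an improvement larger than `min_delta`, whose default is 1e-4, with cooldown 0), replayed over the `val_loss` values from the log. It assumes this run used `LR = 1e-3` with `factor=0.2, patience=5, min_lr=1e-7` as quoted earlier in the thread; `replay_reduce_lr_on_plateau` is a hypothetical helper, not part of Keras.

```python
def replay_reduce_lr_on_plateau(val_losses, lr, factor=0.2, patience=5,
                                min_lr=1e-7, min_delta=1e-4):
    """Simplified replay of ReduceLROnPlateau (mode='min', cooldown=0).

    Returns the learning rate in effect after each epoch and the 1-based
    epochs at which a reduction would have been triggered.
    """
    best = float("inf")
    wait = 0
    lrs, reductions = [], []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:  # counts as an improvement: reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                reductions.append(epoch)
                wait = 0
        lrs.append(lr)
    return lrs, reductions


# val_loss for epochs 1-15, copied from the log above
val_loss = [1.2105, 0.9429, 0.9518, 1.0308, 1.0440, 0.9821, 0.9419, 1.0377,
            0.9414, 1.1248, 1.0416, 0.9391, 0.9393, 0.9382, 0.9826]

_, fired = replay_reduce_lr_on_plateau(val_loss, lr=1e-3, patience=5)
print("patience=5, reductions at epochs:", fired)

lrs, fired = replay_reduce_lr_on_plateau(val_loss, lr=1e-3, patience=2)
print("patience=2, reductions at epochs:", fired, "final lr:", lrs[-1])
```

With `patience=5` the replay triggers no reduction in these 15 epochs: the noisy `val_loss` reaches a marginally new minimum (epochs 7, 9, 12 and 14) just often enough to reset the counter, so under the assumed settings the learning rate probably stayed at 1e-3 throughout. Lowering `patience` to 2 fires at epochs 4, 6 and 11, and raising `min_delta` to an illustrative 0.005 (keeping `patience=5`) fires at epochs 7 and 12, because those tiny dips no longer count as improvements. Setting `verbose=1`, as suggested earlier in the thread, shows whether the same happens in the real run.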

Update regarding TPU vs. GPU: apparently each epoch takes just 110 seconds on a single GPU, compared with about 750 seconds on the 8 TPU cores, even though the latter should be faster.

Update regarding a simplified model: I tried training on only normal vs. tumor vs. stroma, with class weights 0.5, 0.5, 1, 0 (the last one again zero, for background), and the model doesn't seem to be learning at all:

[image not reproduced]
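One possible contributor to the high loss values is how the class weights enter the loss. Assuming the loss is built as in the library's multi-class example (`sm.losses.DiceLoss(class_weights=...)`, possibly plus a focal term), and assuming the library multiplies each class's F1/Dice score by its weight and then averages over all classes (which is how its source appears to work), weights of 0.5, 0.5, 1, 0 cap the weighted score at 2/4 = 0.5, so the Dice term cannot fall below 0.5 even for a perfect prediction. The pure-Python sketch below mirrors that averaging under this assumption; `weighted_dice_loss` is a hypothetical helper and the four-pixel masks are toy data.

```python
def weighted_dice_loss(gt, pr, class_weights, smooth=1e-5):
    """1 - mean over classes of (weight * per-class F1/Dice score).

    gt, pr: lists of pixels, each a list of per-class values (one-hot or
    probabilities). Mirrors, as an assumption, how segmentation_models
    weights and averages per-class scores; this is not the library's code.
    """
    n_classes = len(class_weights)
    weighted_scores = []
    for c in range(n_classes):
        tp = sum(g[c] * p[c] for g, p in zip(gt, pr))
        fp = sum(p[c] for p in pr) - tp
        fn = sum(g[c] for g in gt) - tp
        score = (2 * tp + smooth) / (2 * tp + fn + fp + smooth)
        weighted_scores.append(class_weights[c] * score)
    return 1 - sum(weighted_scores) / n_classes


# Toy 4-pixel mask; classes: normal, tumor, stroma, background (one-hot per pixel)
gt = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
perfect = [row[:] for row in gt]

weights = [0.5, 0.5, 1, 0]
print("best achievable loss, raw weights:", weighted_dice_loss(gt, perfect, weights))

# Rescaling so the weights average to 1 moves the best achievable value to 0
normalized = [w * len(weights) / sum(weights) for w in weights]
print("normalized weights:", normalized)
print("best achievable loss, normalized:", weighted_dice_loss(gt, perfect, normalized))
```

A constant offset like this does not by itself stop the model from learning, but it means the reported loss should be read relative to that floor rather than relative to 0.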

Thanks for your help!