qubvel / segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
MIT License

Using a different backbone? #431

Open tonyboston-au opened 3 years ago

tonyboston-au commented 3 years ago

My question is: how do I load weights from a custom backbone (i.e., weights other than those already bundled with qubvel/segmentation_models)?

I have trained ImageNet models such as ResNet50 for image classification using satellite data and would like to use the resulting weights as the ResNet50 encoder within a Unet model for segmentation rather than the default ResNet50 ImageNet weights. Does anyone know if this is possible, and, if so, any clues on how to go about this?

Thanks in advance...

JordanMakesMaps commented 3 years ago

Yes, this is definitely possible. If you download the ImageNet weights from Keras for ResNet50 and load them into the encoder portion of the U-Net model in this repo, it will work fine. Likewise, if you train your own ResNet50 from scratch, you can load those weights into the encoder of the U-Net model. I've done it before; I believe I had to create a custom function to do it, but it wasn't too difficult.

tonyboston-au commented 3 years ago

Thanks Jordan. Existing code looks like this:

BACKBONE = 'resnet50'
model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=n_classes, activation=activation, input_shape=(None, None, N))

which downloads encoder weights from: https://github.com/qubvel/classification_models/releases/download/0.0.1/resnet50_imagenet_1000_no_top.h5

How do you load weights from a local model .h5 file?

JordanMakesMaps commented 3 years ago

Like I said, I believe you need to create a custom function (though some may already exist on the internet; check Stack Overflow) in which you go through each layer of the U-Net encoder and copy over the weights from the corresponding layers of the other encoder. The process is very similar to setting all of the layers in the encoder to frozen or unfrozen. Please note that the snippet below is not tested code; it's just meant to give you an idea of how one could accomplish what you're asking.

# pseudo-code sketch (untested)
unet_model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=n_classes, input_shape=(None, None, N))

# load your custom-trained ResNet50 however you like
custom_model = ResNet50(weights='custom_weights.h5')

# go through each layer in the unet_model,
# find the layers that have the same names as
# the layers in the custom ResNet50 model,
# and copy over the weights.
layer_names = [layer.name for layer in custom_model.layers]

for layer in unet_model.layers:
    if layer.name in layer_names:
        # copy the weights from the matching layer of the custom model
        layer.set_weights(custom_model.get_layer(layer.name).get_weights())

You get the basic idea..
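
As an aside, when the custom weights are in an .h5 file saved with save_weights, Keras can do this name-based matching itself. A minimal sketch, assuming the layer names in the file actually match the encoder's layer names (which, as the rest of this thread shows, is the sticking point):

# Minimal sketch: Keras matches layers by name and skips any whose weight shapes differ.
unet_model.load_weights('custom_weights.h5', by_name=True, skip_mismatch=True)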

tonyboston-au commented 3 years ago

Thanks Jordan - I get the idea, but the implementation is more difficult! For the segmentation_models U-Net model:

BACKBONE = 'resnet50'
model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=n_classes, activation=activation, input_shape=(None, None, N))

you can get the layer names using:

for layer in model.layers:
    print(layer.name)

The names look like this:

data bn_data zero_padding2d_1 conv0 bn0 relu0 zero_padding2d_2 pooling0 stage1_unit1_bn1 stage1_unit1_relu1 stage1_unit1_conv1 stage1_unit1_bn2 stage1_unit1_relu2 zero_padding2d_3 stage1_unit1_conv2 stage1_unit1_bn3 stage1_unit1_relu3 stage1_unit1_conv3 stage1_unit1_sc add_1 stage1_unit2_bn1... etc

For the ResNet50 model, the layer names are:

input_1 conv1_pad conv1 bn_conv1 activation_1 pool1_pad max_pooling2d_1 res2a_branch2a bn2a_branch2a activation_2 res2a_branch2b bn2a_branch2b activation_3 res2a_branch2c res2a_branch1 bn2a_branch2c bn2a_branch1 add_17... etc

So the names don't match. I'm not sure how to link them, and not sure whether all of the weights or just some of them need to be transferred over.

Any advice is appreciated...

JordanMakesMaps commented 3 years ago

Hi @tonyboston-au, because the encoder (i.e., ResNet50) is within the U-Net structure, Keras treats it as a single layer (thus hiding the names of all of the ResNet50 layers). Can you confirm whether the layer names you're showing are the ResNet50 layers within the U-Net model, or just the layer names of the U-Net model itself?

It would be something like:

layer_index_for_resnet_50 = N

# print all resnet layers
for layer in unet.layers[layer_index_for_resnet_50].layers:
    print(layer.name)

JordanMakesMaps commented 3 years ago

Also, see #36, they discuss the same thing but for a different reason.

JordanMakesMaps commented 3 years ago

My apologies! I see now that the U-Net model doesn't treat the encoder as a nested layer (derp, not enough coffee), and I understand the difficulties with the differences in naming. However, although the names are different, they should correspond to the same layers; otherwise qubvel wouldn't have been able to load the ImageNet weights into the encoders.

tonyboston-au commented 3 years ago

Can't work out how the layer names correspond as they look so different. Unet layer names are neat and tidy but ResNet50 names are messy...

Unet layer names:

data bn_data zero_padding2d_1 conv0 bn0 relu0 zero_padding2d_2 pooling0 stage1_unit1_bn1 stage1_unit1_relu1 stage1_unit1_conv1 stage1_unit1_bn2 stage1_unit1_relu2 zero_padding2d_3 stage1_unit1_conv2 stage1_unit1_bn3 stage1_unit1_relu3 stage1_unit1_conv3 stage1_unit1_sc add_1 stage1_unit2_bn1 stage1_unit2_relu1 stage1_unit2_conv1 stage1_unit2_bn2 stage1_unit2_relu2 zero_padding2d_4 stage1_unit2_conv2 stage1_unit2_bn3 stage1_unit2_relu3 stage1_unit2_conv3 add_2 stage1_unit3_bn1 stage1_unit3_relu1 stage1_unit3_conv1 stage1_unit3_bn2 stage1_unit3_relu2 zero_padding2d_5 stage1_unit3_conv2 stage1_unit3_bn3 stage1_unit3_relu3 stage1_unit3_conv3 add_3 stage2_unit1_bn1 stage2_unit1_relu1 stage2_unit1_conv1 stage2_unit1_bn2 stage2_unit1_relu2 zero_padding2d_6 stage2_unit1_conv2 stage2_unit1_bn3 stage2_unit1_relu3 stage2_unit1_conv3 stage2_unit1_sc add_4 stage2_unit2_bn1 stage2_unit2_relu1 stage2_unit2_conv1 stage2_unit2_bn2 stage2_unit2_relu2 zero_padding2d_7 stage2_unit2_conv2 stage2_unit2_bn3 stage2_unit2_relu3 stage2_unit2_conv3 add_5 stage2_unit3_bn1 stage2_unit3_relu1 stage2_unit3_conv1 stage2_unit3_bn2 stage2_unit3_relu2 zero_padding2d_8 stage2_unit3_conv2 stage2_unit3_bn3 stage2_unit3_relu3 stage2_unit3_conv3 add_6 stage2_unit4_bn1 stage2_unit4_relu1 stage2_unit4_conv1 stage2_unit4_bn2 stage2_unit4_relu2 zero_padding2d_9 stage2_unit4_conv2 stage2_unit4_bn3 stage2_unit4_relu3 stage2_unit4_conv3 add_7 stage3_unit1_bn1 stage3_unit1_relu1 stage3_unit1_conv1 stage3_unit1_bn2 stage3_unit1_relu2 zero_padding2d_10 stage3_unit1_conv2 stage3_unit1_bn3 stage3_unit1_relu3 stage3_unit1_conv3 stage3_unit1_sc add_8 stage3_unit2_bn1 stage3_unit2_relu1 stage3_unit2_conv1 stage3_unit2_bn2 stage3_unit2_relu2 zero_padding2d_11 stage3_unit2_conv2 stage3_unit2_bn3 stage3_unit2_relu3 stage3_unit2_conv3 add_9 stage3_unit3_bn1 stage3_unit3_relu1 stage3_unit3_conv1 stage3_unit3_bn2 stage3_unit3_relu2 zero_padding2d_12 stage3_unit3_conv2 stage3_unit3_bn3 stage3_unit3_relu3 stage3_unit3_conv3 add_10 stage3_unit4_bn1 stage3_unit4_relu1 stage3_unit4_conv1 stage3_unit4_bn2 stage3_unit4_relu2 zero_padding2d_13 stage3_unit4_conv2 stage3_unit4_bn3 stage3_unit4_relu3 stage3_unit4_conv3 add_11 stage3_unit5_bn1 stage3_unit5_relu1 stage3_unit5_conv1 stage3_unit5_bn2 stage3_unit5_relu2 zero_padding2d_14 stage3_unit5_conv2 stage3_unit5_bn3 stage3_unit5_relu3 stage3_unit5_conv3 add_12 stage3_unit6_bn1 stage3_unit6_relu1 stage3_unit6_conv1 stage3_unit6_bn2 stage3_unit6_relu2 zero_padding2d_15 stage3_unit6_conv2 stage3_unit6_bn3 stage3_unit6_relu3 stage3_unit6_conv3 add_13 stage4_unit1_bn1 stage4_unit1_relu1 stage4_unit1_conv1 stage4_unit1_bn2 stage4_unit1_relu2 zero_padding2d_16 stage4_unit1_conv2 stage4_unit1_bn3 stage4_unit1_relu3 stage4_unit1_conv3 stage4_unit1_sc add_14 stage4_unit2_bn1 stage4_unit2_relu1 stage4_unit2_conv1 stage4_unit2_bn2 stage4_unit2_relu2 zero_padding2d_17 stage4_unit2_conv2 stage4_unit2_bn3 stage4_unit2_relu3 stage4_unit2_conv3 add_15 stage4_unit3_bn1 stage4_unit3_relu1 stage4_unit3_conv1 stage4_unit3_bn2 stage4_unit3_relu2 zero_padding2d_18 stage4_unit3_conv2 stage4_unit3_bn3 stage4_unit3_relu3 stage4_unit3_conv3 add_16 bn1 relu1 decoder_stage0_upsampling decoder_stage0_concat decoder_stage0a_conv decoder_stage0a_bn decoder_stage0a_relu decoder_stage0b_conv decoder_stage0b_bn decoder_stage0b_relu decoder_stage1_upsampling decoder_stage1_concat decoder_stage1a_conv decoder_stage1a_bn decoder_stage1a_relu decoder_stage1b_conv decoder_stage1b_bn decoder_stage1b_relu decoder_stage2_upsampling decoder_stage2_concat decoder_stage2a_conv decoder_stage2a_bn decoder_stage2a_relu decoder_stage2b_conv decoder_stage2b_bn decoder_stage2b_relu decoder_stage3_upsampling decoder_stage3_concat decoder_stage3a_conv decoder_stage3a_bn decoder_stage3a_relu decoder_stage3b_conv decoder_stage3b_bn decoder_stage3b_relu decoder_stage4_upsampling decoder_stage4a_conv decoder_stage4a_bn decoder_stage4a_relu decoder_stage4b_conv decoder_stage4b_bn decoder_stage4b_relu final_conv softmax

ResNet50 layer names:

input_1 conv1_pad conv1_conv conv1_bn conv1_relu pool1_pad pool1_pool conv2_block1_1_conv conv2_block1_1_bn conv2_block1_1_relu conv2_block1_2_conv conv2_block1_2_bn conv2_block1_2_relu conv2_block1_0_conv conv2_block1_3_conv conv2_block1_0_bn conv2_block1_3_bn conv2_block1_add conv2_block1_out conv2_block2_1_conv conv2_block2_1_bn conv2_block2_1_relu conv2_block2_2_conv conv2_block2_2_bn conv2_block2_2_relu conv2_block2_3_conv conv2_block2_3_bn conv2_block2_add conv2_block2_out conv2_block3_1_conv conv2_block3_1_bn conv2_block3_1_relu conv2_block3_2_conv conv2_block3_2_bn conv2_block3_2_relu conv2_block3_3_conv conv2_block3_3_bn conv2_block3_add conv2_block3_out conv3_block1_1_conv conv3_block1_1_bn conv3_block1_1_relu conv3_block1_2_conv conv3_block1_2_bn conv3_block1_2_relu conv3_block1_0_conv conv3_block1_3_conv conv3_block1_0_bn conv3_block1_3_bn conv3_block1_add conv3_block1_out conv3_block2_1_conv conv3_block2_1_bn conv3_block2_1_relu conv3_block2_2_conv conv3_block2_2_bn conv3_block2_2_relu conv3_block2_3_conv conv3_block2_3_bn conv3_block2_add conv3_block2_out conv3_block3_1_conv conv3_block3_1_bn conv3_block3_1_relu conv3_block3_2_conv conv3_block3_2_bn conv3_block3_2_relu conv3_block3_3_conv conv3_block3_3_bn conv3_block3_add conv3_block3_out conv3_block4_1_conv conv3_block4_1_bn conv3_block4_1_relu conv3_block4_2_conv conv3_block4_2_bn conv3_block4_2_relu conv3_block4_3_conv conv3_block4_3_bn conv3_block4_add conv3_block4_out conv4_block1_1_conv conv4_block1_1_bn conv4_block1_1_relu conv4_block1_2_conv conv4_block1_2_bn conv4_block1_2_relu conv4_block1_0_conv conv4_block1_3_conv conv4_block1_0_bn conv4_block1_3_bn conv4_block1_add conv4_block1_out conv4_block2_1_conv conv4_block2_1_bn conv4_block2_1_relu conv4_block2_2_conv conv4_block2_2_bn conv4_block2_2_relu conv4_block2_3_conv conv4_block2_3_bn conv4_block2_add conv4_block2_out conv4_block3_1_conv conv4_block3_1_bn conv4_block3_1_relu conv4_block3_2_conv conv4_block3_2_bn conv4_block3_2_relu conv4_block3_3_conv conv4_block3_3_bn conv4_block3_add conv4_block3_out conv4_block4_1_conv conv4_block4_1_bn conv4_block4_1_relu conv4_block4_2_conv conv4_block4_2_bn conv4_block4_2_relu conv4_block4_3_conv conv4_block4_3_bn conv4_block4_add conv4_block4_out conv4_block5_1_conv conv4_block5_1_bn conv4_block5_1_relu conv4_block5_2_conv conv4_block5_2_bn conv4_block5_2_relu conv4_block5_3_conv conv4_block5_3_bn conv4_block5_add conv4_block5_out conv4_block6_1_conv conv4_block6_1_bn conv4_block6_1_relu conv4_block6_2_conv conv4_block6_2_bn conv4_block6_2_relu conv4_block6_3_conv conv4_block6_3_bn conv4_block6_add conv4_block6_out conv5_block1_1_conv conv5_block1_1_bn conv5_block1_1_relu conv5_block1_2_conv conv5_block1_2_bn conv5_block1_2_relu conv5_block1_0_conv conv5_block1_3_conv conv5_block1_0_bn conv5_block1_3_bn conv5_block1_add conv5_block1_out conv5_block2_1_conv conv5_block2_1_bn conv5_block2_1_relu conv5_block2_2_conv conv5_block2_2_bn conv5_block2_2_relu conv5_block2_3_conv conv5_block2_3_bn conv5_block2_add conv5_block2_out conv5_block3_1_conv conv5_block3_1_bn conv5_block3_1_relu conv5_block3_2_conv conv5_block3_2_bn conv5_block3_2_relu conv5_block3_3_conv conv5_block3_3_bn conv5_block3_add conv5_block3_out avg_pool predictions

There are 177 ResNet50 layer names and 231 U-Net layer names for the full model or 190 for just the encoder part of the U-Net model (with a ResNet50 backbone). I know the mapping to load weights is possible but it doesn't seem to be a simple one to one correspondence of the layers.
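
One way to see how far (if at all) the two encoders line up is to compare the shapes of their weighted layers side by side. A rough diagnostic sketch (not a transfer script), assuming unet_model is the sm.Unet model from earlier and resnet50_model is the Keras ResNet50:

# Diagnostic sketch: pair up the layers that actually carry weights and compare their shapes.
# If the two shape sequences drift apart, neither a positional nor a name-based copy will work.
def weighted_layers(model):
    return [(layer.name, [w.shape for w in layer.get_weights()])
            for layer in model.layers if layer.get_weights()]

for (u_name, u_shapes), (r_name, r_shapes) in zip(weighted_layers(unet_model),
                                                  weighted_layers(resnet50_model)):
    flag = 'match' if u_shapes == r_shapes else 'MISMATCH'
    print(f'{flag:8s} {u_name:30s} {u_shapes} <-> {r_name:30s} {r_shapes}')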

pluniak commented 3 years ago

I think it might be easier to re-train your model with the ResNet50 implementation of this repo. Alternatively, you could add the decoder to your existing model. It's not that difficult.
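
One way to sketch that first suggestion: train the classifier with the same ResNet50 implementation that segmentation_models uses for its encoder (the companion qubvel/classification_models package), so the layer names match and Keras can transfer the weights by name. A rough, untested sketch, assuming the tf.keras backend and treating n_classes and the training step as placeholders (use classification_models.keras instead of classification_models.tfkeras on standalone Keras):

import segmentation_models as sm
import tensorflow as tf
from classification_models.tfkeras import Classifiers

# the same ResNet50 implementation that sm.Unet builds its encoder from
ResNet50, preprocess_input = Classifiers.get('resnet50')

base = ResNet50(input_shape=(224, 224, 3), weights='imagenet', include_top=False)
x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
out = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
classifier = tf.keras.Model(base.input, out)

# ... train `classifier` on the satellite classification task, then save its weights ...
classifier.save_weights('satellite_resnet50.h5')

# the encoder shares layer names with the classifier, so Keras can match them by name
unet = sm.Unet('resnet50', encoder_weights=None, classes=n_classes, activation='softmax')
unet.load_weights('satellite_resnet50.h5', by_name=True, skip_mismatch=True)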

tonyboston-au commented 3 years ago

Thanks for your help @JordanMakesMaps and @pluniak. @pluniak - these sound like more doable options. I have moved on to other things for now but may come back to it later and try your suggestions...

adolfogc commented 2 years ago

Hi @tonyboston-au,

I don't know if you have already solved this by yourself, but I wanted to share a quick modification you could use, in case it helps:

from segmentation_models.models.unet import build_unet, DecoderUpsamplingX2Block, DecoderTransposeX2Block
# freeze_model is an internal helper of segmentation_models; adjust the import path if it
# lives elsewhere in your installed version
from segmentation_models.models._utils import freeze_model

def CustomBackboneUnet(backbone,
                       classes=1,
                       activation='sigmoid',
                       encoder_freeze=False,
                       encoder_features=None,
                       decoder_block_type='upsampling',
                       decoder_filters=[256, 128, 64, 32, 16],
                       decoder_use_batchnorm=True,
                       **kwargs):

  if decoder_block_type == 'upsampling':
    decoder_block = DecoderUpsamplingX2Block
  elif decoder_block_type == 'transpose':
    decoder_block = DecoderTransposeX2Block
  else:
    raise ValueError('Decoder block type should be in ("upsampling", "transpose"). '
                    'Got: {}'.format(decoder_block_type))

  if encoder_features is None:
    raise ValueError('Please provide encoder features')

  model = build_unet(
        backbone=backbone,
        decoder_block=decoder_block,
        skip_connection_layers=encoder_features,
        decoder_filters=decoder_filters,
        classes=classes,
        activation=activation,
        n_upsample_blocks=len(decoder_filters),
        use_batchnorm=decoder_use_batchnorm,
    )

  # lock encoder weights for fine-tuning
  if encoder_freeze:
      freeze_model(backbone)

  return model

So, assuming you use the ResNet50 provided by TensorFlow/Keras:

from tensorflow.keras.applications import ResNet50

backbone = ResNet50(weights=model_weights_path,
                    input_shape=(256, 256, 3),
                    include_top=False,
                    pooling=None)

backbone_features = ('conv4_block6_out', 'conv3_block4_out', 'conv2_block3_out', 'conv1_relu')

model = CustomBackboneUnet(backbone, encoder_features=backbone_features, encoder_freeze=True)
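
If you adapt this to another backbone, one way to sanity-check the chosen encoder_features is to print the output shape of each named layer; they should sit at progressively higher spatial resolutions so the decoder has a skip connection to concatenate at each upsampling step. (This check is an illustrative addition, not part of the snippet above.)

# Illustrative check: confirm each skip-connection layer exists in the backbone
# and inspect its output shape (the spatial resolution of that skip connection).
for name in backbone_features:
    print(name, backbone.get_layer(name).output_shape)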

tonyboston-au commented 2 years ago

Thanks @adolfogc. I've moved on from this, as the differences in performance using weights from different sources appear to be minor, but I appreciate the suggestions.