Project-MONAI / MONAILabel

MONAI Label is an intelligent open source image labeling and learning tool.
https://docs.monai.io/projects/label
Apache License 2.0

Finetune new model by freezing all layers except FC #1298

Open AHarouni opened 1 year ago

AHarouni commented 1 year ago

Is your feature request related to a problem? Please describe.
No. I think an essential use case is being able to fine-tune a model. Currently, the code can load a model and then continue training all of its layers. I am looking for a way to freeze all layers except the last FC layer and train only a new FC layer with fewer classes.

Describe the solution you'd like
A simple way to:

  1. Pass in the model name/checkpoint.
  2. Specify the name of the FC layer to keep training while freezing the rest of the layers.

Describe alternatives you've considered
I wrote the function below to copy weights and freeze layers for SegResNet. It excludes the last layer, named conv_0.conv_0, from the copy so that it can be trained from scratch. I load this model in the init of my app. However, training doesn't converge, so I think something is missing.

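```python
import torch
from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups


def pruneModelFCLayer(dst_model, src_model, checkptPath):
    # Load the pretrained weights into the source model.
    checkpoint = torch.load(checkptPath)
    src_model_state_dict = checkpoint.get("model", checkpoint)
    src_model.load_state_dict(src_model_state_dict, strict=False)

    # Copy every matching weight except the last layer (regex "conv_0.conv_0").
    new_model_state_dict, updated_keys, unchanged_keys = copy_model_state(
        dst_model, src_model, exclude_vars="conv_0.conv_0", inplace=False
    )
    print(f"unchanged keys {unchanged_keys}")
    dst_model.load_state_dict(new_model_state_dict)

    # Stop gradients for the pretrained (copied) weights.
    for name, param in dst_model.named_parameters():
        if name in updated_keys:
            param.requires_grad = False

    # NOTE: with include_others=False, this group holds only the parameters in
    # updated_keys, i.e. the frozen pretrained weights; the new last layer is
    # never handed to the optimizer. Filtering on `x[0] not in updated_keys`
    # would select the new layer instead.
    params = generate_param_groups(
        network=dst_model,
        layer_matches=[lambda x: x[0] in updated_keys],
        match_types=["filter"],
        lr_values=[1e-4],
        include_others=False,
    )

    return dst_model, params
```
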
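One likely cause of non-convergence with a frozen backbone is an optimizer that never receives the new layer's parameters. A minimal sketch of the intended pattern, using a toy `nn.Sequential` in place of SegResNet (the module names `backbone` and `head` and the helper `freeze_all_but` are hypothetical): freeze everything except the head, then build the optimizer only from parameters that still require gradients.

```python
import torch
from torch import nn


def freeze_all_but(model: nn.Module, trainable_prefix: str) -> list:
    """Freeze every parameter whose name does not start with trainable_prefix.

    Returns the parameters that remain trainable.
    """
    trainable = []
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefix)
        if param.requires_grad:
            trainable.append(param)
    return trainable


torch.manual_seed(0)
# Toy stand-in for a pretrained network: a "backbone" plus a new "head".
model = nn.Sequential()
model.add_module("backbone", nn.Linear(4, 8))
model.add_module("head", nn.Linear(8, 3))

trainable = freeze_all_but(model, "head.")
# Only the trainable (head) parameters are passed to the optimizer.
optimizer = torch.optim.AdamW(trainable, lr=1e-2)

backbone_before = model.backbone.weight.detach().clone()
head_before = model.head.weight.detach().clone()

# One training step on random data.
x = torch.randn(16, 4)
y = torch.randint(0, 3, (16,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(model.backbone.weight, backbone_before))  # True: backbone unchanged
print(torch.equal(model.head.weight, head_before))          # False: head updated
```

Passing only the trainable parameters keeps the optimizer state small and makes the intent explicit; passing all parameters also works, because AdamW skips parameters whose gradient is `None`.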
SachidanandAlle commented 1 year ago

I guess this kind of tailoring is always possible. MONAI Label provides the required interfaces to extend the infer/train actions. I'm not sure whether it is in MONAI Label's interest to support one kind of customization for one kind of network.

For example, if you want to customize how the network is loaded (without the last layer, etc.) or need some other customization, you can always override the method def _get_network(self, device): https://github.com/Project-MONAI/MONAILabel/blob/main/monailabel/tasks/infer/basic_infer.py#L424

AHarouni commented 1 year ago

This issue is mainly about training a network while freezing layers to fine-tune the last layers, not about inference. You are correct, though; I am able to do that now. Below is my attempt; I am still verifying that I inserted the code in the correct spot.

```python
import logging
import re

import torch
from monai.networks.utils import get_state_dict
from monailabel.tasks.train.basic_train import Context

logger = logging.getLogger(__name__)


def freezeLayers(network, vars_name, include_exclude):
    # State-dict keys that match the regex `vars_name`.
    src_dict = get_state_dict(network)
    to_skip = {s_key for s_key in src_dict if vars_name and re.compile(vars_name).search(s_key)}
    logger.info("--------------------- in freeze layer")
    for name, para in network.named_parameters():
        # include_exclude=True: freeze everything that does NOT match (train only the matches).
        # include_exclude=False: freeze only the matches.
        freeze = (name not in to_skip) if include_exclude else (name in to_skip)
        if freeze:
            logger.info(f"------------ freezing {name} size {para.shape}")
            para.requires_grad = False


# Method of the app's trainer (a BasicTrainTask subclass).
def optimizer(self, context: Context):
    if self.freeze_layers:
        # Modifies self._network in place.
        freezeLayers(self._network, "conv_final_xx", self.train_DeepGrow)
        context.network = self._network
    return torch.optim.AdamW(context.network.parameters(), lr=1e-4, weight_decay=1e-5)
```

The freezeLayers function is adapted from MONAI core. My ask is to add a new interface function that does this freezing given a layer name to include or exclude. This should cover 90% of the use cases, where a user takes a large model released by MONAI Label, such as the TotalSegmentator model, and fine-tunes it on a small dataset (in my case, 20 volumes).
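The requested interface could look roughly like the sketch below. `freeze_layers` and its `mode` argument are hypothetical names, not an existing MONAI or MONAI Label API; the patterns are regular expressions matched with `re.search` against parameter names, mirroring the `freezeLayers` helper above.

```python
import re

from torch import nn


def freeze_layers(network: nn.Module, patterns, mode: str = "train_matching") -> list:
    """Freeze parameters of `network` by regex and return the frozen names.

    mode="train_matching": freeze every parameter that matches no pattern.
    mode="freeze_matching": freeze only parameters that match a pattern.
    Parameters that are not frozen keep their current requires_grad flag.
    """
    if mode not in ("train_matching", "freeze_matching"):
        raise ValueError(f"unknown mode: {mode}")
    regexes = [re.compile(p) for p in patterns]
    frozen = []
    for name, param in network.named_parameters():
        matched = any(r.search(name) for r in regexes)
        freeze = (not matched) if mode == "train_matching" else matched
        if freeze:
            param.requires_grad = False
            frozen.append(name)
    return frozen


# Toy network: an "encoder" plus a new final layer named like the one in this thread.
net = nn.Sequential()
net.add_module("encoder", nn.Conv3d(1, 4, kernel_size=3))
net.add_module("conv_final_xx", nn.Conv3d(4, 8, kernel_size=1))

print(freeze_layers(net, [r"^conv_final_xx\."]))
# ['encoder.weight', 'encoder.bias']
```

Anchoring the pattern (`^conv_final_xx\.`) avoids accidentally matching other layers whose names merely contain the same substring.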

tangy5 commented 1 year ago

I feel this could be part of a bigger feature request.

This is about how to fine-tune models in MONAI Label. Currently, we can only train networks with a fixed definition. However, the prevalent strategy is GPT-style: users may want to load pre-trained weights and fine-tune the model in a custom way, either by adjusting the output channels or by using cascaded networks (e.g., label the liver first, then label the tumor based on the liver mask).

We can continue discussing this. If MONAI Label could support flexible input and output channels for a network, that would be useful. Freezing layers is more of a developer feature; we can expect developers to be familiar with MONAI Label app code.

AHarouni commented 1 year ago

The ask here is to add some new hooks (methods in the interface) at the point where the model is loaded from a checkpoint (when the user selects a pretrained model). These hooks would let the user modify how the model is loaded, for example by freezing layers or by loading the weights in a custom way.

diazandr3s commented 1 year ago

I guess this is only one part of the ask. Users should also be able to change the first layer when working with more than one modality. It might also be good if users could change the number of labels from the Slicer/OHIF viewer plugin.
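The thread does not say how the first layer should be changed for multi-modal input. One common approach, shown here as an assumption rather than MONAI Label behaviour, builds a new first convolution with more input channels and initialises it from the pretrained kernel. `expand_input_channels` is a hypothetical helper; the first convolution's attribute name in a real model can be found with `print(model)`.

```python
import torch
from torch import nn


def expand_input_channels(conv: nn.Conv3d, new_in: int) -> nn.Conv3d:
    """Return a new Conv3d that accepts `new_in` input channels.

    Assumes groups == 1. The pretrained kernel is summed over its input
    channels and split evenly across the `new_in` new channels, so feeding
    the same image to every input channel reproduces the original response.
    """
    new_conv = nn.Conv3d(
        new_in,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        summed = conv.weight.sum(dim=1, keepdim=True)  # (out, 1, kD, kH, kW)
        new_conv.weight.copy_(summed.repeat(1, new_in, 1, 1, 1) / new_in)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


torch.manual_seed(0)
old = nn.Conv3d(1, 4, kernel_size=3, padding=1)  # pretrained single-modality layer
new = expand_input_channels(old, new_in=2)       # e.g. two co-registered modalities

x = torch.randn(1, 1, 8, 8, 8)
x2 = x.repeat(1, 2, 1, 1, 1)  # same image in both channels
print(torch.allclose(old(x), new(x2), atol=1e-5))  # True
```

This initialisation only gives a sensible starting point; the new first layer would normally be left trainable so it can learn modality-specific weights.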

AHarouni commented 1 year ago

I figured this one out. Below are the details of what I did to fine-tune the TotalSegmentator model. @SachidanandAlle, I think the ask now is whether we can add the freezeLayers function to MONAI Label and into the flow of BasicTrainTask. Maybe the load function of BasicTrainTask could be exposed, i.e. load_checkpoint instead of the private _load_checkpoint, so that a user can call super().load_checkpoint() and then freezeLayers, instead of hacking it into the optimizer code as I did below. I am just throwing out ideas here, as I didn't understand how the flow of _load_checkpoint works, since it appends to the train handlers.
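On the ordering question raised here: MONAI's `CheckpointLoader` handler attaches to ignite's `Events.STARTED`, so if `_load_checkpoint` appends such a handler, the weights are presumably loaded after `optimizer()` has already frozen the layers. The plain-PyTorch sketch below (a toy model, not MONAI Label code) checks that this order is safe: `load_state_dict` copies values into the existing parameters and leaves their `requires_grad` flags alone.

```python
import torch
from torch import nn

torch.manual_seed(0)


def make_model() -> nn.Sequential:
    # Toy stand-in for the network: a "backbone" plus a "head".
    model = nn.Sequential()
    model.add_module("backbone", nn.Linear(4, 8))
    model.add_module("head", nn.Linear(8, 3))
    return model


pretrained = make_model()
checkpoint = {"model": pretrained.state_dict()}  # stand-in for a saved checkpoint

model = make_model()
# 1) Freeze first, as the optimizer() override does.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("head.")

# 2) Load the weights afterwards, as a handler fired at engine start would.
model.load_state_dict(checkpoint["model"])

for name, param in model.named_parameters():
    print(name, param.requires_grad)
# backbone.weight False
# backbone.bias False
# head.weight True
# head.bias True
```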

I used the TotalSegmentator model released by MONAI Label as my pretrained model and froze all layers except a new last FC layer for just 8 labels, which I trained on just 100 volumes.

The model was trained using SegResNet, so I wrote my own network inheriting from SegResNet:

```python
import torch
from monai.networks.nets import SegResNet


class SegResNetNewFC(SegResNet):
    def __init__(self, spatial_dims: int = 3, init_filters: int = 8, out_channels: int = 2, **kwargs):
        # use_conv_final=False: the parent's forward() stops before its final conv.
        super().__init__(spatial_dims=spatial_dims, init_filters=init_filters, use_conv_final=False, **kwargs)
        self.conv_final = None  # delete the original final layer
        # New final layer under a new name, so checkpoint keys for conv_final do not map onto it.
        self.conv_final_xx = self._make_final_conv(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Builds everything except the last conv, since use_conv_final=False.
        beforeFC = super().forward(x)
        return self.conv_final_xx(beforeFC)
```

This way, the MONAI Label load function works without changes, because I renamed the last conv layer from conv_final to conv_final_xx.
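Why the rename matters can be illustrated with plain PyTorch. In the toy modules below (hypothetical stand-ins for SegResNet), keeping the name `conv_final` with a different output size makes `load_state_dict` raise a size-mismatch error even with `strict=False`, while the renamed head loads cleanly: body weights are copied and the checkpoint's `conv_final.*` keys are simply reported as unexpected. MONAI Label's own loader may handle mismatches differently (for example, by skipping keys whose shapes differ), so this only shows the plain `load_state_dict` behaviour.

```python
from torch import nn


class Pretrained(nn.Module):
    # Stand-in for the released model: a body plus a final conv named conv_final.
    def __init__(self, out_channels: int):
        super().__init__()
        self.body = nn.Conv3d(1, 8, kernel_size=3, padding=1)
        self.conv_final = nn.Conv3d(8, out_channels, kernel_size=1)


class Renamed(nn.Module):
    # Same body, but the new head lives under a different attribute name.
    def __init__(self, out_channels: int):
        super().__init__()
        self.body = nn.Conv3d(1, 8, kernel_size=3, padding=1)
        self.conv_final_xx = nn.Conv3d(8, out_channels, kernel_size=1)


ckpt = Pretrained(out_channels=10).state_dict()

# Same name, different shape: PyTorch raises even with strict=False.
try:
    Pretrained(out_channels=8).load_state_dict(ckpt, strict=False)
    same_name_loaded = True
except RuntimeError as err:
    same_name_loaded = False
    print("size mismatch" in str(err))  # True

# New name: body weights load, the head keeps its fresh initialisation.
result = Renamed(out_channels=8).load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['conv_final_xx.weight', 'conv_final_xx.bias']
print(result.unexpected_keys)  # ['conv_final.weight', 'conv_final.bias']
```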

Next, I changed the optimizer method of my trainer to freeze the layers:

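```python
def optimizer(self, context: Context):
    # Freeze everything except the new final layer (modifies self._network in place).
    freezeLayers(self._network, ["conv_final_xx"], True)
    context.network = self._network
    # Frozen parameters get no gradient, so AdamW leaves them untouched.
    return torch.optim.AdamW(context.network.parameters(), lr=1e-4, weight_decay=1e-5)  # too small
```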

Below is the freezeLayers function that it calls:

```python
import logging
import re

from monai.networks.utils import get_state_dict

logger = logging.getLogger(__name__)


def freezeLayers(network, exclude_vars, exclude_include=True):
    """Freeze parameters by regex.

    exclude_include=True: train only parameters matching one of `exclude_vars`; freeze the rest.
    exclude_include=False: freeze only the parameters matching one of `exclude_vars`.
    """
    src_dict = get_state_dict(network)
    to_skip = set()
    for exclude_var in exclude_vars:
        if exclude_var:
            pattern = re.compile(exclude_var)
            to_skip.update(s_key for s_key in src_dict if pattern.search(s_key))
    logger.info(f"--------------------- layer freeze with {exclude_include=}")
    for name, para in network.named_parameters():
        freeze = (name not in to_skip) if exclude_include else (name in to_skip)
        if freeze:
            logger.info(f"------------ freezing {name} size {para.shape}")
            para.requires_grad = False
        else:
            logger.info(f"----training ------- {name} size {para.shape}")
```

SachidanandAlle commented 1 year ago

_load_checkpoint is protected, not private, so you should be able to override it in a derived class.