mindspore-lab / mindcv

A toolbox of vision models and algorithms based on MindSpore
https://mindspore-lab.github.io/mindcv/
Apache License 2.0

loading pretrained weights error when num_classes is not 1000 #697

Closed. XianyunSun closed this issue 3 months ago.

XianyunSun commented 1 year ago

Describe the bug (Mandatory)

I run into an error when running net = mindcv.create_model('convnext_tiny', pretrained=True, num_classes=1): the shapes of the parameters of the last classification layer do not match, even though, according to the demo, they should not be loaded when num_classes is not 1000. Although classification.weight and classification.bias are popped when loading parameters, the remaining adam_v.classification.weight, adam_v.classification.bias, adam_m.classification.weight and adam_m.classification.bias entries are still loaded, which causes the error.
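A minimal sketch of why the error happens. The key names below are modeled on the report above and are illustrative; a real MindSpore checkpoint maps names to Parameter objects, not shape tuples.

```python
# Toy checkpoint contents: classifier weights for 1000 classes plus their
# Adam optimizer-state copies, which share the same shapes.
param_dict = {
    "classification.weight": (1000, 768),
    "classification.bias": (1000,),
    "adam_v.classification.weight": (1000, 768),
    "adam_v.classification.bias": (1000,),
    "adam_m.classification.weight": (1000, 768),
    "adam_m.classification.bias": (1000,),
    "features.0.weight": (96, 3, 4, 4),
}

# The original helper pops only the two exact classifier keys ...
for key in ("classification.weight", "classification.bias"):
    param_dict.pop(key, None)

# ... so the optimizer-state copies survive and still carry the
# 1000-class shapes, which is what clashes with num_classes=1.
leftover = [k for k in param_dict if "classification" in k]
assert leftover == [
    "adam_v.classification.weight",
    "adam_v.classification.bias",
    "adam_m.classification.weight",
    "adam_m.classification.bias",
]
```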

I solved this problem by changing a few lines in mindcv/models/helpers.py:

  1. I changed _search_param_name to return a list of all parameter names that contain classification.weight or classification.bias:

    def _search_param_name(params_names: List[str], param_name: str) -> List[str]:
        same_param_names = []
        for pi in params_names:
            if param_name in pi:
                same_param_names.append(pi)
        return same_param_names
  2. I changed part of the load_pretrained function to pop all parameter names returned by _search_param_name:
    elif num_classes != default_cfg["num_classes"]:
        params_names = list(param_dict.keys())
        same_param_names = _search_param_name(params_names, classifier_name + ".weight")
        for param_name in same_param_names:
            param_dict.pop(param_name, "No Parameter {} in ParamDict".format(param_name))
        same_param_names = _search_param_name(params_names, classifier_name + ".bias")
        for param_name in same_param_names:
            param_dict.pop(param_name, "No Parameter {} in ParamDict".format(param_name))

I'm not sure whether these changes affect other parts of the code, but they do fix this problem.

The-truthh commented 7 months ago

You can check the parameter list with net.get_parameters() after running net = mindcv.create_model('convnext_tiny', pretrained=True, num_classes=1). You will see that the parameters adam_v.classification.weight, adam_v.classification.bias, adam_m.classification.weight and adam_m.classification.bias are not included in the net.

So even if these four parameters are not removed from the parameter dict loaded from the checkpoint, they will not be loaded into the net.
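The loading rule described above can be simulated in plain Python: only checkpoint entries whose names also exist among the network's parameters get copied in, and extra keys are skipped. This mirrors how mindspore.load_param_into_net treats checkpoint keys absent from the net; the names below are illustrative.

```python
# Names the network actually owns (no optimizer-state entries).
net_param_names = {"features.0.weight", "classification.weight", "classification.bias"}

# Checkpoint with leftover Adam state for the classifier.
ckpt = {
    "features.0.weight": "backbone weights",
    "adam_v.classification.weight": "optimizer state",  # not a net parameter
    "adam_m.classification.bias": "optimizer state",    # not a net parameter
}

# Only keys present in the net are loaded; the rest are ignored.
loaded = {k: v for k, v in ckpt.items() if k in net_param_names}
assert set(loaded) == {"features.0.weight"}
```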