Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0

Transfer Learning possible? #1450

Closed: Phyrokar closed this issue 1 year ago

Phyrokar commented 1 year ago

💡 Your Question

I would like to know whether it is possible to use my own pretrained weights. I tried `net = models.get(Models.YOLO_NAS_M, num_classes=3, pretrained_weights="/testing_1/average_model.pth")` and got the following error:

ValueError: `pretrained_weights="testing_1/average_model.pth" is not a valid and was not found in that platform. Valid pretrained weights are: "dict_keys(['imagenet', 'imagenet21k', 'coco_segmentation_subclass', 'cityscapes', 'coco', 'coco_pose', 'cifar10'])"

Versions

super-gradients 3.2.0

BloodAxe commented 1 year ago

To use your own pretrained weights, the right option is `checkpoint_params.checkpoint_path`. If the model you are going to train is not 100% compatible with the weights (say, you have completely changed the head design), you may also want to pass `checkpoint_params.strict_load=key_matching`, which loads only the layers whose names and shapes both match.
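To make the `key_matching` rule concrete, here is a minimal plain-Python sketch of name-and-shape filtering, with tensor shapes stored as tuples instead of real tensors. `filter_matching_weights` and the parameter names are hypothetical and chosen for illustration only; this is not SG's internal implementation, just the rule described above: a checkpoint entry is loaded only if the model has a parameter with the same name and the same shape.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def filter_matching_weights(
    model_shapes: Dict[str, Shape], checkpoint_shapes: Dict[str, Shape]
) -> Tuple[Dict[str, Shape], List[str]]:
    """Hypothetical helper illustrating name-and-shape ("key matching") loading.

    Not SG's implementation: keeps a checkpoint entry only if the model has a
    parameter with the same name AND the same shape, and lists the skipped keys.
    """
    loadable: Dict[str, Shape] = {}
    skipped: List[str] = []
    for name, shape in checkpoint_shapes.items():
        if model_shapes.get(name) == shape:
            loadable[name] = shape
        else:
            skipped.append(name)
    return loadable, skipped


# Toy parameter names/shapes (hypothetical): the checkpoint head predicts
# 80 classes, the new model's head predicts 3, and the checkpoint also has
# a key the new model does not contain.
model = {
    "backbone.conv1.weight": (48, 3, 3, 3),
    "head.classifier.weight": (3, 96, 1, 1),
}
checkpoint = {
    "backbone.conv1.weight": (48, 3, 3, 3),
    "head.classifier.weight": (80, 96, 1, 1),
    "head.extra.bias": (96,),
}

loadable, skipped = filter_matching_weights(model, checkpoint)
print(sorted(loadable))  # ['backbone.conv1.weight']
print(skipped)           # ['head.classifier.weight', 'head.extra.bias']
```

In this toy run only the backbone entry would be transferred; the 3-class head keeps whatever values the model already had, because its shape differs from the checkpoint's 80-class head.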

The `pretrained_weights` option is intended for a different task: it specifies a dataset name, and SG will try to download the pretrained weights of the given model for that dataset.
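The distinction can be summarized as "dataset key versus file path". The sketch below routes a user-supplied value to one keyword or the other. `weights_kwargs` and `KNOWN_DATASET_KEYS` are hypothetical names; the dataset keys are copied from the error message in the question (SG 3.2.0), and the sketch assumes that `models.get` in your installed version accepts a `checkpoint_path` keyword (the recipe-level equivalent named in the answer is `checkpoint_params.checkpoint_path`).

```python
from typing import Dict

# Dataset keys listed in the ValueError quoted in the question (SG 3.2.0).
KNOWN_DATASET_KEYS = {
    "imagenet", "imagenet21k", "coco_segmentation_subclass",
    "cityscapes", "coco", "coco_pose", "cifar10",
}


def weights_kwargs(weights: str) -> Dict[str, str]:
    """Hypothetical helper: choose the keyword for a weights spec.

    A known dataset key goes to ``pretrained_weights``; anything else is
    treated as a path to a local checkpoint file and goes to
    ``checkpoint_path`` (assumed keyword, check your installed version).
    """
    if weights in KNOWN_DATASET_KEYS:
        return {"pretrained_weights": weights}
    return {"checkpoint_path": weights}


print(weights_kwargs("coco"))
# {'pretrained_weights': 'coco'}
print(weights_kwargs("/testing_1/average_model.pth"))
# {'checkpoint_path': '/testing_1/average_model.pth'}
```

Under that assumption, the returned dict could be unpacked into the call from the question, e.g. `models.get(Models.YOLO_NAS_M, num_classes=3, **weights_kwargs(...))`, so that a local `.pth` file is never passed as `pretrained_weights`.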