Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0

error loading pre-trained weights for fine-tuning YOLONAS #949

Status: closed by albertofernandezvillan 1 year ago

albertofernandezvillan commented 1 year ago

Instantiating a YOLONAS model for fine-tuning:

from super_gradients.training import models
model = models.get('yolo_nas_l', num_classes=len(dataset_params['classes']), pretrained_weights="coco")

The line above works, but the following line (I just copied the downloaded weights into the current working directory):

model = models.get('yolo_nas_l', num_classes=len(dataset_params['classes']), pretrained_weights="coco", checkpoint_path="./yolo_nas_l_coco.pth")

It throws an error when loading the head weights, because my dataset has a different number of classes than the COCO checkpoint. I want to load the weights from a file on disk for fine-tuning.

How can I work around this error?
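For context on the error: the COCO checkpoint's classification head is sized for COCO's 80 classes, so when the model is built with num_classes=len(dataset_params['classes']), the head tensors in the file no longer match the model's shapes and loading fails. A common generic PyTorch workaround (not a super-gradients API) is to keep only the checkpoint entries whose name and shape match the model, then call load_state_dict(filtered, strict=False); the filtering is needed because strict=False only tolerates missing or unexpected keys, not size mismatches. The sketch below models that filter with shape tuples standing in for tensors; the parameter names and the helpers toy_state and split_loadable are hypothetical.

```python
# Toy stand-ins for PyTorch state dicts: parameter name -> tensor shape.
# The parameter names are hypothetical, not YOLO-NAS's real ones.
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def toy_state(num_classes: int) -> Dict[str, Shape]:
    # Backbone shapes ignore the class count; the classification head does not.
    return {
        "backbone.stem.weight": (48, 3, 3, 3),
        "heads.cls_pred.weight": (num_classes, 128, 1, 1),
        "heads.cls_pred.bias": (num_classes,),
    }


def split_loadable(
    model_state: Dict[str, Shape], ckpt_state: Dict[str, Shape]
) -> Tuple[Dict[str, Shape], List[str]]:
    """Keep checkpoint entries whose name and shape both match the model;
    report the rest, which would make loading fail."""
    loadable: Dict[str, Shape] = {}
    skipped: List[str] = []
    for name, shape in ckpt_state.items():
        if model_state.get(name) == shape:
            loadable[name] = shape
        else:
            skipped.append(name)
    return loadable, skipped


coco_ckpt = toy_state(80)  # checkpoint trained on COCO's 80 classes
my_model = toy_state(20)   # model built with num_classes=20

loadable, skipped = split_loadable(my_model, coco_ckpt)
print(sorted(loadable))  # ['backbone.stem.weight']
print(sorted(skipped))   # ['heads.cls_pred.bias', 'heads.cls_pred.weight']
```

In real PyTorch code the dictionary values are tensors, so the check would compare tensor shapes; the skipped head parameters keep their fresh initialisation and are learned during fine-tuning.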


Louis-Dupont commented 1 year ago

Hi @albertofernandezvillan

Below are two ways to load and fine-tune your checkpoint, depending on your case:

1. Standard flow

Step 1: Download the pretrained model and fine-tune it

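from super_gradients.training import Trainer, models

model = models.get('yolo_nas_m', num_classes=20, pretrained_weights="coco")
trainer = Trainer(experiment_name="my_finetune_run")  # experiment name is a placeholder
trainer.train(model=model, training_params=..., train_loader=..., valid_loader=...)
...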

Step 2: Load the fine-tuned model

model = models.get('yolo_nas_m', num_classes=20, checkpoint_path="<path-to-fine-tuned-checkpoint-file>")

Note that:

I believe this should cover your case, but I am aware this is not exactly what you asked for.

2. Fine-tuning from a local checkpoint

If you want to load the model for fine-tuning (i.e., Step 1) from a local checkpoint, and that checkpoint was trained with a different number of classes than you need, do this:

import os
model = models.get('yolo_nas_m', num_classes=20, checkpoint_num_classes=80, checkpoint_path=os.path.expanduser("~/.cache/torch/hub/checkpoints/yolo_nas_m_coco.pth"))

Step 2 remains the same.
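Conceptually, checkpoint_num_classes lets the model first be built with the checkpoint's class count (80 for COCO), so every weight, head included, loads without a shape mismatch; the head is then replaced with a freshly initialised one sized for num_classes. The toy sketch below walks through that order of operations with (shape, source) pairs standing in for tensors. It is an assumption about the idea, not super-gradients' actual implementation, and toy_detector, strict_load, replace_head and get_for_finetuning are hypothetical names.

```python
# Toy model of fine-tuning from a checkpoint with a different class count.
# Assumption: this mirrors the idea behind checkpoint_num_classes, not
# super-gradients' real code. All names below are hypothetical.
from typing import Dict, Tuple

Shape = Tuple[int, ...]
Param = Tuple[Shape, str]  # (tensor shape, where the values came from)
State = Dict[str, Param]


def toy_detector(num_classes: int, source: str = "random init") -> State:
    # Only the classification head depends on num_classes.
    return {
        "backbone.stem.weight": ((48, 3, 3, 3), source),
        "heads.cls_pred.weight": ((num_classes, 128, 1, 1), source),
        "heads.cls_pred.bias": ((num_classes,), source),
    }


def strict_load(model: State, ckpt: State) -> State:
    # Strict loading: same keys on both sides and identical shapes.
    if model.keys() != ckpt.keys():
        raise KeyError("missing or unexpected keys")
    for name, (shape, _) in ckpt.items():
        if model[name][0] != shape:
            raise ValueError(
                f"size mismatch for {name}: checkpoint {shape}, model {model[name][0]}"
            )
    return dict(ckpt)  # every parameter now holds the checkpoint's values


def replace_head(model: State, new_num_classes: int) -> State:
    # Keep the loaded non-head parameters; swap in a freshly initialised head.
    fresh = toy_detector(new_num_classes)
    return {
        name: fresh[name] if name.startswith("heads.") else param
        for name, param in model.items()
    }


def get_for_finetuning(ckpt: State, num_classes: int, checkpoint_num_classes: int) -> State:
    model = toy_detector(checkpoint_num_classes)  # 1. head sized like the checkpoint
    model = strict_load(model, ckpt)              # 2. all weights load cleanly
    if num_classes != checkpoint_num_classes:
        model = replace_head(model, num_classes)  # 3. new head for your classes
    return model


coco_ckpt = toy_detector(80, source="coco")

# Building directly with 20 classes fails on the head, like the original error.
try:
    strict_load(toy_detector(20), coco_ckpt)
except ValueError as err:
    print(err)  # size mismatch for heads.cls_pred.weight: checkpoint (80, 128, 1, 1), model (20, 128, 1, 1)

model = get_for_finetuning(coco_ckpt, num_classes=20, checkpoint_num_classes=80)
for name, (shape, source) in model.items():
    print(name, shape, source)
# backbone.stem.weight (48, 3, 3, 3) coco
# heads.cls_pred.weight (20, 128, 1, 1) random init
# heads.cls_pred.bias (20,) random init
```

The point of the ordering is that the backbone keeps its COCO weights while only the head starts from scratch, which is what you want when fine-tuning on a dataset with a different set of classes.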

I hope this helps!