Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0
4.46k stars 481 forks

Load Model From The Disk #1423

Closed adilraja closed 3 months ago

adilraja commented 10 months ago

🚀 Feature Request

Hi, I am trying to train yolo_nas_l.pt (and similar models) on a custom project. I have to load the model from disk: I cannot use super_gradients.training.models.get(...) to fetch the model from the web, because my script runs on a remote server that does not allow a Python script to download from the Internet. So I have to have the weights on disk. Is there a method in super_gradients that would allow me to load a .pt model from disk? I tried torch.load(...), but it failed with an error. It would be nice to have this feature in super_gradients.

Best regards, Muhammad Adil Raja

Proposed Solution (Optional)

Perhaps it would be nice to have a get function with a slightly different signature that loads the model from disk.

bit-scientist commented 10 months ago

Hello, please share the error you got while loading the model from disk. Also, consider including your environment details, such as OS and package versions.

adilraja commented 10 months ago

Hi all, I get this error when I try to do torch.load('yolo_nas_l.pt') from the local disk. I wonder if you can help me with this.

`AttributeError: 'YoloNASBottleneck' object has no attribute 'drop_path'`

Best regards, Dr. Muhammad Adil Raja Postdoctoral researcher Regulated Software Research Centre (RSRC) Dundalk Institute of Technology (DkIT) Ireland


BloodAxe commented 10 months ago

Can you please provide the exact code you are using to load the model from disk? It looks like you are doing something wrong, as loading checkpoints from a file is indeed supported and described here. Check the checkpoint_path argument. Note that you need to pass both num_classes and checkpoint_path simultaneously.
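That call pattern might look like the sketch below. The architecture name, argument values, and helper name are placeholders chosen for illustration; it assumes super_gradients is installed and that the file is a checkpoint produced by super_gradients training.

```python
def load_local_yolo_nas(checkpoint_path: str, num_classes: int):
    """Sketch: build yolo_nas_l and load weights from a local checkpoint file."""
    from super_gradients.training import models  # deferred import; needs super_gradients installed
    return models.get(
        "yolo_nas_l",
        num_classes=num_classes,          # number of classes in YOUR dataset
        checkpoint_path=checkpoint_path,  # local file path; no network access needed
    )
```

Called as e.g. `load_local_yolo_nas("yolo_nas_l.pt", num_classes=3)`, this never hits the network, which matters on an air-gapped server.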

adilraja commented 10 months ago

Hi Eugene, many thanks for this. Here is my code:

```python
import os
import torch

from ultralytics import NAS

from super_gradients.training import models
from super_gradients.training import Trainer
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback

HOME = os.getcwd()
print(HOME)

DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"
print("The device is: ", DEVICE)
print(DEVICE)
MODEL_ARCH = 'yolo_nas_l'

model = torch.load('yolo_nas_l.pt', map_location=torch.device(DEVICE))
```

.... (And then I do:)

```python
trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_data,
    valid_loader=val_data,
)
```

And this is where the error occurs finally.

By the way, thanks indeed for letting me know that loading checkpoints from a file is supported and described here: https://docs.deci.ai/super-gradients/docstring/training/models.html#training.models.model_factory.get ("Check the checkpoint_path. Note that you need to pass both num_classes and checkpoint_path simultaneously.")

I will be sure to check this. But I have the impression that checkpoint_path loads a .pth file, whereas I want to load a .pt file. I wonder whether they are any different.
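For what it's worth, the extension itself carries no meaning to torch.load: .pt and .pth are both just serialized (pickled) files, and what matters is what was saved inside, not the file name. A minimal stdlib simulation of that point, using pickle directly in place of torch.save/torch.load:

```python
import pickle
import tempfile
from pathlib import Path

state = {"layer.weight": [0.1, 0.2]}  # stand-in for a state_dict

with tempfile.TemporaryDirectory() as tmp:
    for name in ("weights.pt", "weights.pth"):
        path = Path(tmp) / name
        path.write_bytes(pickle.dumps(state))     # torch.save also pickles under the hood
        loaded = pickle.loads(path.read_bytes())  # torch.load unpickles it back
        assert loaded == state                    # identical content under either extension
```

The difference that actually matters is whether the file holds a checkpoint dict (a state_dict, possibly nested) or a whole pickled model object.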


adilraja commented 10 months ago

Hi Eugene, OK, so I ran my script according to your advice and got the following error trace.

```
[2023-08-28 12:29:48] INFO - crash_tips_setup.py - Crash tips is enabled. You can set your environment variable to CRASH_HANDLER=FALSE to disable it
[2023-08-28 12:29:52] WARNING - __init__.py - Failed to import pytorch_quantization
/ichec/work/dkcom001c/conda/yolonas/lib/python3.10/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")
[2023-08-28 12:29:56] WARNING - calibrator.py - Failed to import pytorch_quantization
[2023-08-28 12:29:56] WARNING - export.py - Failed to import pytorch_quantization
[2023-08-28 12:29:56] WARNING - selective_quantization_utils.py - Failed to import pytorch_quantization
Traceback (most recent call last):
  File "/ichec/home/users/madil/yolonas/yolonastrainer.py", line 27, in <module>
    model = models.get(MODEL_ARCH, num_classes=len(CLASSES), checkpoint_path="yolo_nas_l.pt").to(DEVICE)
  File "/ichec/work/dkcom001c/conda/yolonas/lib/python3.10/site-packages/super_gradients/training/models/model_factory.py", line 230, in get
    ckpt_entries = read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
  File "/ichec/work/dkcom001c/conda/yolonas/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'YoloNAS_L' object has no attribute 'keys'
```

I used the following to get the model:

```python
model = models.get(MODEL_ARCH, num_classes=len(CLASSES), checkpoint_path="yolo_nas_l.pt").to(DEVICE)
```

And the following to train it:

```python
trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_data,
    valid_loader=val_data,
)
```

Information about my conda environment:

```
(yolonas) $ pip list
Package                       Version
----------------------------- ----------
absl-py                       1.4.0
alabaster                     0.7.13
antlr4-python3-runtime        4.9.3
attrs                         23.1.0
Babel                         2.12.1
boto3                         1.28.35
botocore                      1.31.35
build                         0.10.0
cachetools                    5.3.1
certifi                       2023.7.22
charset-normalizer            3.2.0
click                         8.1.7
cmake                         3.27.2
coloredlogs                   15.0.1
contourpy                     1.1.0
coverage                      5.3.1
cycler                        0.11.0
Deprecated                    1.2.14
docutils                      0.17.1
einops                        0.3.2
filelock                      3.12.2
flatbuffers                   23.5.26
fonttools                     4.42.1
future                        0.18.3
google-auth                   2.22.0
google-auth-oauthlib          1.0.0
grpcio                        1.57.0
humanfriendly                 10.0
hydra-core                    1.3.2
opencv-python                 4.8.0.76
opencv-python-headless        4.8.0.76
packaging                     23.1
Pillow                        9.5.0
pip                           23.2.1
pip-tools                     7.3.0
protobuf                      3.20.3
psutil                        5.9.5
pyasn1                        0.5.0
pyasn1-modules                0.3.0
pycocotools                   2.0.6
pyDeprecate                   0.3.2
Pygments                      2.16.1
pyparsing                     2.4.5
pyproject_hooks               1.0.0
python-dateutil               2.8.2
PyYAML                        6.0.1
rapidfuzz                     3.2.0
referencing                   0.30.2
requests                      2.31.0
requests-oauthlib             1.3.1
rich                          13.5.2
rpds-py                       0.9.2
rsa                           4.9
s3transfer                    0.6.2
scipy                         1.11.2
setuptools                    68.0.0
six                           1.16.0
snowballstemmer               2.2.0
Sphinx                        4.0.3
sphinx-rtd-theme              1.3.0
sphinxcontrib-applehelp       1.0.4
sphinxcontrib-devhelp         1.0.2
sphinxcontrib-htmlhelp        2.0.1
sphinxcontrib-jquery          4.1
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.3
sphinxcontrib-serializinghtml 1.1.5
stringcase                    1.2.0
super-gradients               3.2.0
supervision                   0.13.0
sympy                         1.12
tensorboard                   2.14.0
tensorboard-data-server       0.7.1
termcolor                     1.1.0
tomli                         2.0.1
torch                         2.0.1
torchmetrics                  0.8.0
torchvision                   0.15.2
tqdm                          4.66.1
treelib                       1.6.1
triton                        2.0.0
typing_extensions             4.7.1
urllib3                       1.26.16
Werkzeug                      2.3.7
wheel                         0.38.4
wrapt                         1.15.0

(yolonas) $ python -V
Python 3.10.12
```

By the way, I was able to train yolo_nas_s.pt on my local machine (which has a 6 GB NVIDIA card). But I failed to train it on the remote server (which has much bigger cards).
23pointsNorth commented 10 months ago

Note that you need to pass both num_classes and checkpoint_path simultaneously.

Hi! (Jumping in on this thread.) The suggested solution is to pass num_classes, but I wonder whether that is necessary. After training a model initialized on e.g. COCO, I load average_model.pth, passing the num_classes integer; yet when I run a prediction on a random image, the model output correctly contains the class_names property. This suggests the .pth file stores the class-name strings, and by proxy their length/count. Or am I missing something that is being re-used in the script?
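One way to answer this empirically would be to load the .pth and print its top-level keys to see whether class names really are stored next to the weights. The sketch below mocks that inspection with the stdlib; the checkpoint layout and key names here are invented for illustration and may not match what super_gradients actually writes:

```python
import io
import pickle

# Invented checkpoint layout: metadata stored next to the weights.
mock_ckpt = {
    "net": {"head.cls.weight": [0.0, 0.0]},
    "processing_params": {"class_names": ["person", "car", "dog"]},
}

buf = io.BytesIO()
pickle.dump(mock_ckpt, buf)
buf.seek(0)
loaded = pickle.load(buf)

# If class names are stored, their count could in principle supply num_classes:
num_classes = len(loaded["processing_params"]["class_names"])
```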

adilraja commented 10 months ago

Hi Daniel, I am not talking about further training a .pth model on the same kind of data. I am talking about retraining a .pt model on a different type of data. But I also don't know the difference between .pth and .pt models.



shaydeci commented 3 months ago

It seems to me that you are trying to load, with models.get(), a checkpoint you created on your own.

It is quite hard to know what went wrong when that is the case. We do, however, verify in our unit tests that checkpoints created during training in SG can be loaded that way. I recommend going over our checkpoints docs section here. I will close this issue for now; if further problems come up, please feel free to re-open it.
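The `'YoloNAS_L' object has no attribute 'keys'` error earlier in the thread is consistent with this: checkpoint_path expects a serialized dict (something that has .keys()), whereas the file apparently held a whole pickled model object. A stdlib sketch of that distinction, where FakeModel is illustrative and not a real super_gradients class:

```python
import io
import pickle

class FakeModel:
    """Stand-in for a whole pickled model object (like a directly saved module)."""

def roundtrip(obj):
    """Pickle and unpickle, mimicking a save/load cycle."""
    buf = io.BytesIO()
    pickle.dump(obj, buf)
    buf.seek(0)
    return pickle.load(buf)

ckpt_dict = roundtrip({"net": {"backbone.weight": [0.1]}})  # checkpoint-style dict
model_obj = roundtrip(FakeModel())                          # whole-object file

print(hasattr(ckpt_dict, "keys"))  # True: a dict exposes .keys()
print(hasattr(model_obj, "keys"))  # False: calling .keys() raises AttributeError
```

So saving `torch.save(model.state_dict(), path)` style checkpoints, rather than the model object itself, keeps the file in the shape that checkpoint loading code expects.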