Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0

RuntimeError `set_dataset_processing_params()` even after setting such params before `predict()` #1739

Closed ani-mal closed 9 months ago

ani-mal commented 9 months ago

💡 Your Question

I am trying to run model.predict() on a custom model trained from scratch, and I am getting the following error:

raise RuntimeError(
RuntimeError: You must set the dataset processing parameters before calling predict.
Please call `model.set_dataset_processing_params(...)` first.

However, I am calling set_dataset_processing_params(...) before running the predict() method, and I still get this error.

here is the implementation:

class YoloNAS(BaseModel):
    def __init__(self, model_architecture, path, max_detections_per_image, iou=0.35, conf=0.25):
        super().__init__('yolo_nas', path, max_detections_per_image)

        self.model = models.get(model_architecture, num_classes=len(self.classes), checkpoint_path=path)
        self.model.to(self.device)
        self.model.set_dataset_processing_params(
            class_names=classes_list,
            image_processor=models.get(model_architecture, num_classes=len(self.classes), checkpoint_path=path)._image_processor,
            iou=iou,
            conf=conf,
        )

    def predict(self, frames):
        predictions = []
        for frame in frames:
            frame_predictions = self.model.predict(frame)
            predictions.append(frame_predictions)
        return predictions

and here are the super-gradients version and some other dependencies of my conda env:

super-gradients           3.5.0                    pypi_0    pypi
pytorch-cuda              11.7                 h778d358_5    pytorch
cuda-runtime              11.7.1                        0    nvidia
data-gradients            0.3.1                    pypi_0    pypi


BloodAxe commented 9 months ago

Can you please print the content of checkpoint["processing_params"] of the checkpoint you are trying to load into the model? That would be the first thing to check: whether the checkpoint indeed has the preprocessing params.
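For anyone following along, inspecting a checkpoint's top-level keys takes a couple of lines of plain PyTorch. A minimal sketch (the checkpoint.pth written here is fabricated so the snippet is self-contained; in practice you would torch.load your own file):

```python
import torch

# Fabricate a minimal checkpoint so this sketch is self-contained;
# in practice you would torch.load(...) your own .pth file.
torch.save(
    {
        "net": {"layer.weight": torch.zeros(1)},
        # A checkpoint written by super-gradients' own saving logic also
        # carries "processing_params"; a hand-rolled torch.save may not.
        "processing_params": {"class_names": ["cat", "dog"], "iou": 0.35, "conf": 0.25},
    },
    "checkpoint.pth",
)

loaded = torch.load("checkpoint.pth")
print(sorted(loaded.keys()))            # does "processing_params" exist at all?
print(loaded.get("processing_params"))  # None here means predict() will raise
```

If the second print shows None (or the key is missing entirely), that explains the RuntimeError.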

I don't quite get why you need this call:

        self.model.set_dataset_processing_params(
            class_names=classes_list,
            image_processor=models.get(model_architecture, num_classes=len(self.classes), checkpoint_path=path)._image_processor,
            iou=iou,
            conf=conf,
        )

This single line should be self-sufficient to load everything (model weights, preprocessing params), assuming the checkpoint contains this data: self.model = models.get(model_architecture, num_classes=len(self.classes), checkpoint_path=path)

If the checkpoint does not have saved preprocessing params in the first place, then models.get(model_architecture, num_classes=len(self.classes), checkpoint_path=path)._image_processor would return None anyway.

So let's start by checking whether the checkpoint indeed has all the required transforms. If you can provide additional information (training recipe, the SG version you used to train the model that produced this checkpoint, or any other relevant information), that would help.

ani-mal commented 9 months ago

@BloodAxe I think I figured out the issue and it is related to the previous issues I posted.

I was having trouble originally loading a previously trained model to further finetune it: https://github.com/Deci-AI/super-gradients/issues/1570

I ended up saving the model like this:

state_dict = model.state_dict()
checkpoint = {"net": state_dict}
torch.save(checkpoint, "checkpoint.pth")
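In other words, the hand-rolled dict above carried only "net", so there were no processing params for predict() to restore. A hedged sketch of what was missing (plain torch; the nested key names mirror the set_dataset_processing_params arguments and the error above, the values are placeholders — real super-gradients checkpoints also store the image processor here):

```python
import torch

state_dict = {"layer.weight": torch.zeros(1)}  # stand-in for model.state_dict()

# Storing the processing params next to the weights is what lets
# models.get(..., checkpoint_path=...) restore them for predict().
checkpoint = {
    "net": state_dict,
    "processing_params": {
        "class_names": ["class_a", "class_b"],  # placeholder class names
        "iou": 0.35,
        "conf": 0.25,
    },
}
torch.save(checkpoint, "checkpoint.pth")
reloaded = torch.load("checkpoint.pth")
print("processing_params" in reloaded)  # True
```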

I removed all this custom torch.save(...) logic and figured out how to leverage the built-in way super-gradients saves the checkpoint to a specified location. I think we had issues in the past figuring out which path on the remote cluster we could save the files to, since we don't have access to arbitrary paths; it has to be a specific location on the cluster where we have write access.
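For readers hitting the same path problem: the built-in saving route boils down to pointing the Trainer at a writable directory. A configuration sketch, not runnable as-is (the experiment name and path are hypothetical; super-gradients writes checkpoints such as ckpt_best.pth under ckpt_root_dir/experiment_name):

```python
from super_gradients import Trainer
from super_gradients.training import models

# Hypothetical experiment name and path -- pick a cluster location
# where you actually have write access.
trainer = Trainer(
    experiment_name="yolo_nas_custom",
    ckpt_root_dir="/data/writable/checkpoints",
)

# trainer.train(model=model, training_params=..., train_loader=..., valid_loader=...)
# After training, checkpoints (weights plus processing params) land under
# /data/writable/checkpoints/yolo_nas_custom/ and can be reloaded in one call:
# model = models.get("yolo_nas_s", num_classes=..., checkpoint_path=".../ckpt_best.pth")
```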

Now the model.pth dictionary has all the fields super-gradients is expecting :D

Thanks again for your time and help!