abdelaziz-mahdy / pytorch_lite

flutter package to help run pytorch lite models classification and YoloV5 and YoloV8.
MIT License
58 stars, 25 forks

Error in implementing Custom ResNet50 model #82

Open MeDenTec opened 1 month ago

MeDenTec commented 1 month ago

I am trying to deploy my custom-trained ResNet-50 classification model using the pytorch_lite library. The model works well in Python but gives wrong predictions in Flutter: it always assigns the highest score to the first class.

I optimized my model using the code provided in the README:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

best_model = best_model.to('cpu')
best_model.eval()

# Trace with a dummy input of the expected shape, then optimize and save for the lite interpreter
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(best_model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("Fitrzpack_model.pt")

I don't know the exact reason, but maybe this is due to the transformation I apply to my images during prediction, or to something else. I have also provided the transformation code below.

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

I am also sharing the code I use to run the prediction in Python:

def predict(x):
    img = Image.open(x).convert("RGB")
    img = transform(np.array(img))
    print("shape of transformed image", img.shape)
    img = img.view(1, 3, 224, 224)
    print("shape of reshaped image", img.shape)
    best_model.eval()
    with torch.no_grad():
        if torch.cuda.is_available():
            img = img.cuda()

        out = best_model(img)

        return out.argmax(1).item()

Please help me resolve this, @abdelaziz-mahdy.

abdelaziz-mahdy commented 1 month ago

Most probably it's a problem with the input, so please share your full prediction code for both Dart and Python.

abdelaziz-mahdy commented 1 month ago

Also, does it give the same results on Android and iOS?

MeDenTec commented 1 month ago

Hi @abdelaziz-mahdy, Thank you so much for your response.

This is how I am loading the model

String pathImageModel =
        "assets/models/best_Fitrzpack_withNorm_extraTransformations_Adam_epoch_22.pt";
    try {
      _imageModel = await PytorchLite.loadClassificationModel(
          pathImageModel, 224, 224, 5,
          labelPath: "assets/labels/label_classification_imageNet.txt");

And this is how I am making the prediction in Dart. I also tried the default mean and std, but the results were similar.

textToShow = await _imageModel!.getImagePrediction(
        await File(image.path).readAsBytes(),
        mean: [0.485, 0.456, 0.406],
        std: [0.229, 0.224, 0.225]);
    textToShow = "${textToShow ?? ""}, ${inferenceTimeAsString(stopwatch)}";

Here you can also see the output of the prediction list. The first value is always the highest, so the first class is predicted for every image:

[37.45313262939453, 20.606271743774414, 19.992067337036133, 9.70131778717041, 32.07078552246094, -79.898193359375, -94.40465545654297, -99.96098327636719, -95.91107940673828, -102.2623291015625, -94.18521881103516, -104.00025939941406, -97.89167022705078, -102.9541244506836, -101.01667785644531, -101.24066925048828, -103.78877258300781, -96.97942352294922, -94.73454284667969, -104.4057388305664, -98.69512939453125, -96.98098754882812, -95.92950439453125, -99.71725463867188, -36.406681060791016, -95.3742446899414, -93.91866302490234, -93.91046142578125, -97.62044525146484, -78.75592803955078, -96.0226821899414, -100.22040557861328, -108.44017791748047, -81.2840805053711, -96.94139862060547, -92.8824234008789, -102.83641815185547, -104.82340240478516, -91.39096069335938, -103.66374206542969, -104.43321990966797, -106.07219696044922, -106.11819458007812, -91.6676254272461, -99.1351089477539, -102.29850006103516, -99.13842010498047, -90.38143157958984, -106.93729400634766, -90.3895263671875, -93.90796661376953,
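As a quick sanity check on a raw prediction list like the one above, the following minimal sketch (plain Python, no dependencies) reports whether the number of scores matches the expected class count and which indices score highest. The helper name summarize_scores and its expected_classes parameter are hypothetical, and only the first eight values of the list above are used. A list longer than the 5 passed to loadClassificationModel (presumably the class count) would suggest that the exported model's output layer is not the size the app expects.

```python
def summarize_scores(scores, expected_classes, top_k=3):
    """Return (count_matches, top_indices) for a raw score list."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return len(scores) == expected_classes, ranked[:top_k]


# First eight values copied from the prediction list above (the rest omitted).
scores = [37.45313262939453, 20.606271743774414, 19.992067337036133,
          9.70131778717041, 32.07078552246094, -79.898193359375,
          -94.40465545654297, -99.96098327636719]

count_matches, top_indices = summarize_scores(scores, expected_classes=5)
print(count_matches)  # False: more than 5 scores were returned
print(top_indices)    # [0, 4, 1]
```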

Python code

Below is the Python code used for model loading and export:


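import torch
import torch.nn as nn
from copy import deepcopy
from torchvision.models import resnet50, ResNet50_Weights

# Model initialization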
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_ftrs = resnet.fc.in_features
resnet.fc.in_features = nn.Linear(num_ftrs, OUT_CLASSES)

model = deepcopy(resnet)

# Custom trained model loading
model_path ="models/pytorch_skin_types/best_Fitrzpack_withNorm_extraTransformations_Adam_epoch_22.pth"
best_model = model
best_model.load_state_dict(torch.load(model_path,weights_only=True))
best_model.eval()

# Export code used to optimize the model
from torch.utils.mobile_optimizer import optimize_for_mobile
best_model = best_model.to('cpu')

best_model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(best_model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("Fitrzpack_model.pt")
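
One thing that may be worth checking in the initialization above, independent of the export step: assigning a module to fc.in_features does not replace the classification head. In PyTorch, setting a module as an attribute registers it as a submodule, so the original fc weight and its output size stay unchanged. That would be consistent with a prediction list containing many more than 5 values. The sketch below shows the difference on a tiny stand-in layer; the sizes 8, 10, and 3 are arbitrary, not the ResNet-50 ones.

```python
import torch
import torch.nn as nn

# Stand-in for resnet.fc: 8 input features, 10 outputs (arbitrary sizes).
head = nn.Linear(8, 10)

# Same pattern as the initialization above: this registers a submodule
# named "in_features" on the layer; forward() still uses the old weight.
head.in_features = nn.Linear(8, 3)
print(head(torch.zeros(1, 8)).shape)   # torch.Size([1, 10])
print(sorted(head.state_dict().keys()))

# Replacing the layer itself is what changes the output size.
model = nn.Sequential(nn.Linear(8, 10))
model[0] = nn.Linear(8, 3)
print(model(torch.zeros(1, 8)).shape)  # torch.Size([1, 3])
```

Note that switching an already trained model to the replacement form would change the state-dict keys and shapes, so an existing checkpoint would likely need retraining or remapping before it loads.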

Below is the prediction function used in python

label_index = {"Phototype_I_&_II": 0, "Phototype_III": 1, "Phototype_IV": 2, "Phototype_V": 3, "Phototype_VI": 4}
index_label = {0 :"Phototype_I_&_II", 1 : "Phototype_III", 2 : "Phototype_IV", 3 : "Phototype_V", 4 : "Phototype_VI"}

# Transformations
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# prediction function
def predict(x):
    img = Image.open(x).convert("RGB")
    img = transform(np.array(img))
    print("shape of transformed image", img.shape)
    img = img.view(1, 3, 224, 224)
    print("shape of reshaped image", img.shape)
    best_model.eval()
    with torch.no_grad():
        if torch.cuda.is_available():
            img = img.cuda()

        out = best_model(img)

        return out.argmax(1).item()

print(index_label[predict(r"D:\262541-Faizan\misc_projects\face_analyzer\Dataset\Fitrzpack_test\black.jpg")])
print(index_label[predict(r"D:\262541-Faizan\misc_projects\face_analyzer\Dataset\Fitrzpack_test\white.jpg")])

abdelaziz-mahdy commented 1 month ago

If possible, can you provide the model and a test image so I can try it out and make sure whatever fix I make works?

MeDenTec commented 1 month ago

https://drive.google.com/drive/folders/1ncrauuM5S1VcSIDu485kzQgX46RdOH_P?usp=sharing

Sure, you can get the model and the training notebook in the Drive folder above.

abdelaziz-mahdy commented 1 month ago

Will check it out. I can't promise I will be able to figure it out, but I will try.

MeDenTec commented 1 month ago

Yeah, sure. Have you ever deployed ResNet50 using this library? If so, you could share that too, if possible.

abdelaziz-mahdy commented 1 month ago

I haven't, but any classification model should work correctly, since it's just using PyTorch.

The only catch is that the input has to match the Python one: same RGB format, same normalization, and, I think, the same image encoding too.

So there are a lot of input parameters to get right.
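
To make the "same normalization" requirement concrete, here is a minimal sketch in plain Python of what ToTensor followed by Normalize does to a single RGB pixel: each channel is scaled from 0..255 to 0..1, then the mean is subtracted and the result divided by the std. The helper name normalize_pixel is hypothetical; the mean and std are the ImageNet values used in this thread. Whatever the Dart side does to a pixel has to end up with the same numbers.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Map an (R, G, B) pixel with 0..255 values to normalized floats."""
    return [((value / 255.0) - m) / s for value, m, s in zip(rgb, mean, std)]


# A pure white pixel. Skipping the /255 scaling, or applying the means in
# BGR order, would give different numbers for the same pixel.
print(normalize_pixel((255, 255, 255)))
```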

abdelaziz-mahdy commented 1 month ago

Also, did you try the other preprocessing enum?

MeDenTec commented 1 month ago

I just tried the standard ImageNet mean and std dev, which I also used in Python. I also tried the default parameters, but it always predicts the first class.

MeDenTec commented 1 month ago

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

Considering this Python code for the transformations, how would you suggest calling the prediction function?

Another concern is that my model is not converting into scripted_model.pt properly.

abdelaziz-mahdy commented 1 month ago

The transformation mentioned is the same as the one the package applies.

I don't know whether not converting will hurt; I have only tested with converted models.

abdelaziz-mahdy commented 1 month ago

OK, so I went through both the Python code and the Dart code (I can't try it, since I don't have the labels for the test images).

Did you try using

    List<double?>? predictionList = await _imageModel!.getImagePredictionList(
      await File(image.path).readAsBytes(),
      preProcessingMethod: PreProcessingMethod.imageLib
    );

and

    List<double?>? predictionList = await _imageModel!.getImagePredictionList(
      await File(image.path).readAsBytes(),
      preProcessingMethod: PreProcessingMethod.native
    );

Did both give the same results? From what I am currently seeing, everything should work correctly.

Also, can you load the model file in Python and try it, to make sure the downloaded model is correct?

MeDenTec commented 1 month ago

Yes, the model works well in Python. I have checked it multiple times. I have also provided the code and the model in the link above, so you can check as well. The labels are there too.

MeDenTec commented 1 month ago

The test images I provided are random. You just need to check whether it predicts anything other than the first class.

abdelaziz-mahdy commented 1 month ago

I don't know the correct labels for them, so that will not help 😅

MeDenTec commented 1 month ago

I skipped the optimized_traced_model = optimize_for_mobile(traced_script_module) line and it worked somehow.
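
For reference, a minimal sketch of the export with the optimize_for_mobile step left out, as described above. It uses a tiny stand-in network instead of the trained ResNet-50, and the function name export_without_mobile_optimizer and the output file name are hypothetical. It also compares the traced output with the eager output on the same input, which is a cheap way to catch a bad export before moving to Flutter. Why skipping the optimizer helped here was not established in this thread.

```python
import os
import tempfile

import torch
import torch.nn as nn


def export_without_mobile_optimizer(model, path, input_shape=(1, 3, 224, 224)):
    """Trace the model and save it for the lite interpreter, skipping optimize_for_mobile."""
    model = model.to("cpu").eval()
    example = torch.rand(*input_shape)
    traced = torch.jit.trace(model, example)
    traced._save_for_lite_interpreter(path)
    return traced, example


# Tiny stand-in model (assumption: replace with the trained best_model).
stand_in = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 5),
)

out_path = os.path.join(tempfile.mkdtemp(), "model_no_opt.pt")
traced, example = export_without_mobile_optimizer(stand_in, out_path)

# The traced module should reproduce the eager output on the same input.
with torch.no_grad():
    same = torch.allclose(stand_in(example), traced(example), atol=1e-5)
print(same, os.path.exists(out_path))
```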

I also wanted to know how we can run YOLOv8 classification models using the Flutter pytorch_lite library. I tried the same TorchScript method used for YOLOv8 object detection, but it didn't work.

abdelaziz-mahdy commented 1 month ago

I am glad that your model worked. I was stuck trying to figure out why it was failing.

Sorry for not being able to help.

For YOLOv8, did you create it using the yolo command provided in the README?

MeDenTec commented 1 month ago

Thanks a lot for your help. I used this code, but it's not supported:

model.export(format="", imgsz = 320, optimize = True)

I don't know how we can run YOLOv8 classification models.

abdelaziz-mahdy commented 1 month ago

What do you mean by not supported? It doesn't work? And if there is an error please share it

MeDenTec commented 1 month ago

There was an error when loading the model in Flutter. Maybe we need to convert a YOLO classification model differently from an object detection model. I used the command below, which is obviously for YOLOv8 object detection, but the conversion method for YOLOv8 classification models is not explicitly mentioned in the README.

!yolo mode=export model="your model" format=torchscript optimize

abdelaziz-mahdy commented 1 month ago

Did you try model.export(format="torchscript", imgsz = 320, optimize = True)?

abdelaziz-mahdy commented 1 month ago

@MeDenTec did you try the above export? If yes and it works, please let me know so I can close this issue.

MeDenTec commented 1 month ago

model.export(format="torchscript", imgsz = 320, optimize = True)

Yes, I tried a similar export method, but it didn't work. I think it's only for object detection models, not YOLO image classification models. I don't know how I can run YOLO image classification models.

abdelaziz-mahdy commented 1 month ago

It's in the official docs of YOLOv8, so this is weird.

What didn't work exactly? What were the errors ?

Please provide as much info as possible since I can't fix something I don't know 😅