abdelaziz-mahdy / pytorch_lite

Flutter package to help run PyTorch Lite models for classification, YOLOv5, and YOLOv8.
MIT License

Error in implementing Custom ResNet50 model #82

Open MeDenTec opened 1 day ago

MeDenTec commented 1 day ago

I am trying to deploy my custom-trained ResNet-50 classification model using the pytorch_lite library. The model works well in Python but gives wrong predictions in Flutter, always assigning the highest probability to the first class.

I optimized my model using the code provided in the README:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

best_model = best_model.to('cpu')
best_model.eval()

example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(best_model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("Fitrzpack_model.pt")
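As a quick sanity check, the eager, traced, and optimized modules can be compared on the same input to rule out the export step itself; this is a minimal sketch, assuming the objects from the snippet above are still in scope:

import torch

# Sketch: if these outputs agree, the export itself is likely fine and the
# mismatch more probably comes from preprocessing on the Flutter side.
with torch.no_grad():
    x = torch.rand(1, 3, 224, 224)
    eager_out = best_model(x)
    traced_out = traced_script_module(x)
    optimized_out = optimized_traced_model(x)

print(torch.allclose(eager_out, traced_out, atol=1e-5))     # expected: True
print(torch.allclose(eager_out, optimized_out, atol=1e-4))  # expected: True
print(eager_out.shape)  # expected: [1, number_of_classes]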

I don't know the exact reason, but maybe it is due to the transformations I apply to my images during prediction, or something else. I have also provided the transformation code below.

transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.ToTensor(),
                                transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])
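For reference, here is a minimal sketch of what this pipeline does numerically to an RGB uint8 image (the Resize step is omitted for brevity): pixel values are scaled to [0, 1] by ToTensor and then normalized per channel, which is the convention the mean/std used on the mobile side would need to match:

import numpy as np

# Sketch of the same preprocessing done by hand on a uint8 RGB image of
# shape (H, W, 3); resizing to IMG_SIZE x IMG_SIZE is omitted for brevity.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def preprocess_manual(img_uint8):
    x = img_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - mean) / std                      # Normalize per channel
    return np.transpose(x, (2, 0, 1))         # HWC -> CHW, shape (3, H, W)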

I am also sharing the code I use to run predictions in Python:

def predict(x):
    img = Image.open(x).convert("RGB")
    img = transform(np.array(img))
    print("shape of transformed image", img.shape)
    img = img.view(1, 3, 224, 224)
    print("shape of reshaped image", img.shape)
    best_model.eval()
    with torch.no_grad():
        if torch.cuda.is_available():
            img = img.cuda()

        out = best_model(img)

        return out.argmax(1).item()

Please help me get out of this situation, @abdelaziz-mahdy.

abdelaziz-mahdy commented 1 day ago

Most probably it's a problem with the input, so please share your full prediction code for both Dart and Python.

abdelaziz-mahdy commented 23 hours ago

Also, does it give the same results on Android and iOS?

MeDenTec commented 8 hours ago

Hi @abdelaziz-mahdy, thank you so much for your response.

This is how I am loading the model:

String pathImageModel =
    "assets/models/best_Fitrzpack_withNorm_extraTransformations_Adam_epoch_22.pt";
try {
  _imageModel = await PytorchLite.loadClassificationModel(
      pathImageModel, 224, 224, 5,
      labelPath: "assets/labels/label_classification_imageNet.txt");

And this is how I am making the prediction in Dart. I also tried with the default mean and std, but the results were similar.

textToShow = await _imageModel!.getImagePrediction(
    await File(image.path).readAsBytes(),
    mean: [0.485, 0.456, 0.406],
    std: [0.229, 0.224, 0.225]);
textToShow = "${textToShow ?? ""}, ${inferenceTimeAsString(stopwatch)}";

Here you can also see the output prediction list; the first value is always the highest, hence the first class is always predicted for every image.

[37.45313262939453, 20.606271743774414, 19.992067337036133, 9.70131778717041, 32.07078552246094, -79.898193359375, -94.40465545654297, -99.96098327636719, -95.91107940673828, -102.2623291015625, -94.18521881103516, -104.00025939941406, -97.89167022705078, -102.9541244506836, -101.01667785644531, -101.24066925048828, -103.78877258300781, -96.97942352294922, -94.73454284667969, -104.4057388305664, -98.69512939453125, -96.98098754882812, -95.92950439453125, -99.71725463867188, -36.406681060791016, -95.3742446899414, -93.91866302490234, -93.91046142578125, -97.62044525146484, -78.75592803955078, -96.0226821899414, -100.22040557861328, -108.44017791748047, -81.2840805053711, -96.94139862060547, -92.8824234008789, -102.83641815185547, -104.82340240478516, -91.39096069335938, -103.66374206542969, -104.43321990966797, -106.07219696044922, -106.11819458007812, -91.6676254272461, -99.1351089477539, -102.29850006103516, -99.13842010498047, -90.38143157958984, -106.93729400634766, -90.3895263671875, -93.90796661376953,
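For comparison, here is a minimal sketch (using best_model and transform from the Python code below) that prints the Python-side logits for the same test image, so the two lists can be checked side by side; the image path is just a placeholder:

import numpy as np
import torch
from PIL import Image

# Sketch: print the Python-side logits for one test image so they can be
# compared with the list produced on the Flutter side.
img = Image.open("black.jpg").convert("RGB")   # placeholder path to a test image
x = transform(np.array(img)).unsqueeze(0)
with torch.no_grad():
    logits = best_model(x)
print(logits.shape)        # the number of scores should match the Dart list length
print(logits[0].tolist())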

Python code

Below is the Python code used for model loading and export:


import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

# Model initialization
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_ftrs = resnet.fc.in_features
resnet.fc.in_features = nn.Linear(num_ftrs, OUT_CLASSES)

model = deepcopy(resnet)

# Custom trained model loading
model_path = "models/pytorch_skin_types/best_Fitrzpack_withNorm_extraTransformations_Adam_epoch_22.pth"
best_model = model
best_model.load_state_dict(torch.load(model_path, weights_only=True))
best_model.eval()
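
# Optional sanity check (sketch): confirm the loaded model's output size
# matches the 5 classes declared in loadClassificationModel on the Dart side.
with torch.no_grad():
    check_out = best_model(torch.rand(1, 3, 224, 224))
print(check_out.shape)  # expected: torch.Size([1, OUT_CLASSES]), i.e. [1, 5]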

# Export code used to optimize the model
from torch.utils.mobile_optimizer import optimize_for_mobile
best_model = best_model.to('cpu')

best_model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(best_model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("Fitrzpack_model.pt")
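A further check could be to load the saved lite-interpreter file back in Python and compare it with the eager model. This is only a sketch; it assumes the helper torch.jit.mobile._load_for_lite_interpreter is available in the installed PyTorch version:

import torch
from torch.jit.mobile import _load_for_lite_interpreter

# Sketch: round-trip the exported file and compare its output with the eager model.
lite_model = _load_for_lite_interpreter("Fitrzpack_model.pt")
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    print(torch.allclose(best_model(x), lite_model(x), atol=1e-4))  # expected: True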

Below is the prediction function used in Python:

label_index = {"Phototype_I_&_II": 0, "Phototype_III": 1, "Phototype_IV": 2, "Phototype_V": 3, "Phototype_VI": 4}
index_label = {0: "Phototype_I_&_II", 1: "Phototype_III", 2: "Phototype_IV", 3: "Phototype_V", 4: "Phototype_VI"}

# Transformations
transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.ToTensor(),
                                transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])

# prediction function
def predict(x):
    img = Image.open(x).convert("RGB")
    img = transform(np.array(img))
    print("shape of transformed image", img.shape)
    img = img.view(1, 3, 224, 224)
    print("shape of reshaped image", img.shape)
    best_model.eval()
    with torch.no_grad():
        if torch.cuda.is_available():
            img = img.cuda()

        out = best_model(img)

        return out.argmax(1).item()

print(index_label[predict(r"D:\262541-Faizan\misc_projects\face_analyzer\Dataset\Fitrzpack_test\black.jpg")])
print(index_label[predict(r"D:\262541-Faizan\misc_projects\face_analyzer\Dataset\Fitrzpack_test\white.jpg")])
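Since loadClassificationModel is given a labelPath on the Dart side, it is also worth double-checking that the label file contains exactly these 5 labels in the same index order as index_label above, assuming the file is read one label per line. A minimal sketch to generate such a file; the filename is just an example:

# Sketch: write the 5 labels in index order, one per line, for use as the
# labelPath asset on the Flutter side (the filename here is just an example).
with open("label_classification_skin_types.txt", "w") as f:
    for i in range(len(index_label)):
        f.write(index_label[i] + "\n")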
abdelaziz-mahdy commented 7 hours ago

If possible, can you provide the model and a test image so I can try it out and make sure whatever fix I do works?

MeDenTec commented 5 hours ago

https://drive.google.com/drive/folders/1ncrauuM5S1VcSIDu485kzQgX46RdOH_P?usp=sharing

Sure, you can get the model and the training notebook from the Drive folder above.

abdelaziz-mahdy commented 4 hours ago

Will check it out. I can't promise I'll be able to figure it out, but I will try.