pytorch / android-demo-app

PyTorch android examples of usage in applications

Classification Model keeps predicting same class in Android #238

Status: Open. bebbieyin opened this issue 2 years ago.

bebbieyin commented 2 years ago

I have trained a classification model to classify black-and-white segmented outputs. The pixels take only two values: 0 for black and 1 for white.
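Because the classifier was trained on inputs containing only 0 and 1, one way to keep the device input identical is to build it directly from the segmentation scores as a flat float buffer, without going through a bitmap. The sketch below is a minimal illustration under stated assumptions. The function name `scores_to_binary_chw`, the `channels` parameter, and repeating one mask across channels are hypothetical and do not come from the post. The 0.4 threshold is the one used in the thresholding loop in this issue. The buffer is channel-major `(C, H, W)`, which is the contiguous layout `org.pytorch.Tensor.fromBlob(float[], long[])` expects for a shape of `{1, C, H, W}`.

```python
def scores_to_binary_chw(scores, height, width, channels=1, threshold=0.4):
    """Turn a flat per-pixel score map into a flat buffer of 0.0/1.0 floats.

    Hypothetical helper: the layout is channel-major (C, H, W), and the
    single binary mask is repeated once per channel (an assumption about
    how many input channels the classifier has).
    """
    if len(scores) != height * width:
        raise ValueError("scores must contain height * width values")
    # Same rule as the Java loop: score > threshold -> 1 (white), else 0 (black)
    mask = [1.0 if s > threshold else 0.0 for s in scores]
    # List repetition concatenates whole planes, giving C-major order
    return mask * channels


if __name__ == "__main__":
    buf = scores_to_binary_chw([0.9, 0.1, 0.5, 0.2], height=2, width=2, channels=3)
    print(len(buf), sorted(set(buf)))
```

On Android, the equivalent `float[]` would be paired with the shape `{1, channels, height, width}`, so no pixel-to-float conversion or normalization is involved.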

Before the classification model runs, the original color image goes through a segmentation model that produces the black-and-white bitmap. The segmentation model works fine, and its results closely match the PC version. However, when I feed the black-and-white bitmap into the mobile classification model, it keeps predicting the same class.

I exported the black-and-white bitmap to PNG and ran it through the model on the PC, and there it predicts the correct class.

I created the black and white bitmap like this:

// Threshold the segmentation scores into a binary ARGB mask
for (int j = 0; j < scores.length; j++) {
    if (scores[j] > 0.4) {
        intValues[j] = 0xFFFFFFFF; // white pixels
    } else {
        intValues[j] = 0xFF000000; // black pixels
    }
}
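The issue does not show how this bitmap is converted to a tensor on the device. Suppose it is passed through `TensorImageUtils.bitmapToFloat32Tensor` with `TensorImageUtils.TORCHVISION_NORM_MEAN_RGB` and `TORCHVISION_NORM_STD_RGB`. The post does not say this, so treat it as an assumption. In that case each channel becomes `(value / 255 - mean) / std`. A white pixel then arrives as about 2.25 in the red channel rather than 1, and a black pixel as about -2.12 rather than 0. The plain-Python sketch below mirrors the Java loop and applies that arithmetic, so the values can be compared with what the model saw during training. The helper names are hypothetical.

```python
# Mirror of the Java thresholding loop, to inspect the numbers a classifier
# would receive per pixel. Helper names are hypothetical.
WHITE = 0xFFFFFFFF  # ARGB white, as in the Java snippet
BLACK = 0xFF000000  # ARGB black

# torchvision ImageNet statistics (R, G, B). Assumption: only relevant if the
# bitmap is converted with mean/std normalization on the device.
MEAN_RGB = (0.485, 0.456, 0.406)
STD_RGB = (0.229, 0.224, 0.225)


def threshold_to_argb(scores, threshold=0.4):
    """Same rule as the Java loop: score > threshold -> white, else black."""
    return [WHITE if s > threshold else BLACK for s in scores]


def argb_to_unit_rgb(pixel):
    """Unpack an ARGB int into (r, g, b) floats scaled to [0, 1]."""
    r = (pixel >> 16) & 0xFF
    g = (pixel >> 8) & 0xFF
    b = pixel & 0xFF
    return (r / 255.0, g / 255.0, b / 255.0)


def normalize(rgb, mean=MEAN_RGB, std=STD_RGB):
    """Apply (x - mean) / std to each channel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


if __name__ == "__main__":
    for p in threshold_to_argb([0.9, 0.1]):
        unit = argb_to_unit_rgb(p)
        print(hex(p), unit, [round(v, 3) for v in normalize(unit)])
```

Under this assumption, the scaled `[0, 1]` values already match a 0/1 mask, and it is the mean/std step that moves them away from 0 and 1. This is a diagnostic sketch, not a confirmed explanation of the behavior reported here.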

I've optimized the model for mobile like this:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

script_module = torch.jit.script(pc_classification_model)
optimized_scripted_module = optimize_for_mobile(script_module)
optimized_scripted_module._save_for_lite_interpreter("mobile_classification.ptl")

Libraries used for PyTorch Android:

implementation 'org.pytorch:pytorch_android_lite:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'