pytorch / android-demo-app

PyTorch android examples of usage in applications

Performance mismatch when model is deployed on Android #296

Closed: iampawansingh closed this issue 1 year ago

iampawansingh commented 1 year ago

I used the Hello World package and modified it slightly to create a classifier based on ResNet18. When this model is deployed on Android, its performance degrades a lot: it predicts the same class for every image, whereas it predicts correctly on a Linux machine. I even fed an all-zeros input (`torch.zeros()`) to both and still got different predictions. Detailed steps are given below:

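  1. Serialize the trained model, using the steps mentioned in the tutorial:

     ```python
     import torch
     from torch.utils.mobile_optimizer import optimize_for_mobile

     # model_ft is the trained ResNet18 classifier; eval() puts BatchNorm/Dropout in inference mode
     model_ft.eval()
     example = torch.rand(1, 3, 224, 224)
     traced_script_module = torch.jit.trace(model_ft, example)
     traced_script_module_optimized = optimize_for_mobile(traced_script_module)
     traced_script_module_optimized._save_for_lite_interpreter("resnet18_andorid_20221228_v1.ptl")
     ```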
  2. Load the model on Android using `LiteModuleLoader`:

     ```java
     module = LiteModuleLoader.load(assetFilePath(this, "resnet18_andorid_20221228_v1.ptl"));
     ```

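  3. Create a tensor of all zeros:

     ```java
     // Assumed allocation (the buffer declaration was not shown); a newly allocated
     // direct buffer is zero-filled, so this yields an all-zeros 1x3x224x224 input
     FloatBuffer mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * 224 * 224);
     Tensor mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, 224, 224});
     ```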
  4. Get the model prediction:

     ```java
     final Tensor outputTensor = module.forward(IValue.from(mInputTensor)).toTensor();
     ```

  5. Get the scores as an array:

     ```java
     final float[] scores = outputTensor.getDataAsFloatArray();
     ```

     This gives the scores `[-0.022879722, 0.014740771]`.

  6. The equivalent steps in PyTorch are:

     ```python
     model_ft(torch.unsqueeze(torch.zeros(3, 224, 224), 0))
     ```

     This gives the output `[ 0.1654, -0.5539]`.
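The two results above can be compared numerically. This is a minimal sketch, assuming the scores are copied by hand from the Android app and from the PyTorch session; the helper names (`argmax`, `max_abs_diff`, `outputs_match`) and the `1e-3` tolerance are hypothetical illustration choices, not part of any PyTorch API. With the reported numbers, the largest gap is roughly 0.57 and the predicted class flips from index 0 on the desktop to index 1 on Android, which is consistent with the "same class for every image" symptom.

```python
import math

# Scores reported in this thread for the same all-zeros 1x3x224x224 input.
ANDROID_SCORES = [-0.022879722, 0.014740771]  # pytorch_android_lite 1.12.2
DESKTOP_SCORES = [0.1654, -0.5539]  # eager model_ft in torch 1.12.1


def argmax(scores):
    """Return the index of the highest score, i.e. the predicted class."""
    return max(range(len(scores)), key=lambda i: scores[i])


def max_abs_diff(a, b):
    """Return the largest element-wise absolute difference between two score lists."""
    if len(a) != len(b):
        raise ValueError("score lists must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def outputs_match(a, b, atol=1e-3):
    """Return True if every pair of scores agrees within an absolute tolerance.

    atol=1e-3 is an illustrative assumption, not a PyTorch-recommended value.
    """
    if len(a) != len(b):
        return False
    return all(math.isclose(x, y, rel_tol=0.0, abs_tol=atol) for x, y in zip(a, b))


if __name__ == "__main__":
    print("max abs diff:", max_abs_diff(ANDROID_SCORES, DESKTOP_SCORES))
    print("predicted class (Android, desktop):", argmax(ANDROID_SCORES), argmax(DESKTOP_SCORES))
    print("outputs match:", outputs_match(ANDROID_SCORES, DESKTOP_SCORES))
```

Small floating-point differences between the desktop and mobile backends are normal, so an absolute tolerance is used rather than exact equality; a gap of this size, together with a different `argmax`, points to a real behavioral difference rather than rounding.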

Environment details:

torch version: 1.12.1.post200, torchvision version: 0.13.1

build.gradle:

```groovy
apply plugin: 'com.android.application'

android {
    compileSdkVersion 30
    buildToolsVersion "29.0.2"
    defaultConfig {
        applicationId "org.pytorch.helloworld"
        minSdkVersion 21
        targetSdkVersion 30
        versionCode 1
        versionName "1.0"
    }
    buildTypes {
        release {
            minifyEnabled false
        }
    }
}

dependencies {
    implementation 'androidx.appcompat:appcompat:1.1.0'
    implementation('org.pytorch:pytorch_android_lite:1.12.2') {
        exclude group: 'org.pytorch:pytorch_android_lite:1.10.0'
    }
    implementation('org.pytorch:pytorch_android_torchvision_lite:1.12.2') {
        exclude group: 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'
    }
}
```



Any pointers on how to fix this would be really helpful.
iampawansingh commented 1 year ago

The issue was that I was training the model and serialising it in PyTorch 1.12.1. Training the model in PyTorch 1.12.1 but serialising it in PyTorch 1.10.0 worked.

pocorall commented 1 year ago

Although there is a workaround, this is definitely a bug for PyTorch >= 1.11. Has this issue been reported in the main PyTorch project? If so, please share a link to the issue.
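To catch the version range discussed in the comments above before exporting, the serializing PyTorch version could be checked in the export script. This is a minimal sketch, assuming the boundary reported in this thread (serialization with PyTorch >= 1.11 misbehaving on the lite runtime, while 1.10.0 worked); that boundary comes from these comments, not from an official compatibility statement, and `parse_version` / `reportedly_affected` are hypothetical helper names.

```python
import re

# Boundary as reported in this thread (not an official PyTorch statement):
# models serialized with torch >= 1.11 misbehaved on the Android lite runtime,
# while serializing with torch 1.10.0 worked.
REPORTED_FIRST_BAD = (1, 11)


def parse_version(version):
    """Parse strings such as '1.12.1.post200' or '1.10.0' into (major, minor, patch).

    Anything after the numeric prefix (e.g. '.post200', '+cu113') is ignored.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", str(version))
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def reportedly_affected(version):
    """Return True if this serializing torch version falls in the reported range."""
    return parse_version(version)[:2] >= REPORTED_FIRST_BAD


if __name__ == "__main__":
    for v in ("1.10.0", "1.12.1.post200"):
        print(v, "reportedly affected:", reportedly_affected(v))
```

In an export script, one could call `reportedly_affected(torch.__version__)` before `_save_for_lite_interpreter` and print a warning if it returns `True`.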