pytorch / android-demo-app

PyTorch android examples of usage in applications

Android app crashes after clicking button #184

Open: navidnayyem opened this issue 2 years ago

navidnayyem commented 2 years ago

We are working on a model for a custom dataset and trained it in a Colab notebook. When training finished, the notebook produced the model file model_final.ptl in its outputs. We are now trying to integrate model_final.ptl into the PyTorch Android demo app (GitHub link: https://github.com/pytorch/android-demo-app/tree/master/ImageSegmentation). The app starts and runs fine, but when I tap the button, the app shuts down after a short while. I checked Logcat in Android Studio and found the error messages attached below. We would like a solution for this error. @jeffxtang

[Two screenshots of the Logcat error messages]

jeffxtang commented 2 years ago

As the error message shows, your custom model outputs a tensor (type code 2), but line 124 converts the output to a dictionary (type code 13). The type codes are defined in PyTorch Android's IValue class:

  private static final int TYPE_CODE_TENSOR = 2;
  private static final int TYPE_CODE_BOOL = 3;
  private static final int TYPE_CODE_LONG = 4;
  private static final int TYPE_CODE_DOUBLE = 5;
  private static final int TYPE_CODE_STRING = 6;

  private static final int TYPE_CODE_TUPLE = 7;
  private static final int TYPE_CODE_BOOL_LIST = 8;
  private static final int TYPE_CODE_LONG_LIST = 9;
  private static final int TYPE_CODE_DOUBLE_LIST = 10;
  private static final int TYPE_CODE_TENSOR_LIST = 11;
  private static final int TYPE_CODE_LIST = 12;

  private static final int TYPE_CODE_DICT_STRING_KEY = 13;
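To translate the numeric codes in an error like this into names, a small lookup over the constants listed above is enough. This is a minimal plain-Java sketch with no PyTorch dependency; `TypeCodeDecoder`, `describe`, and `explainMismatch` are hypothetical names, and the mapping only covers the codes quoted in this comment.

```java
import java.util.Map;

public class TypeCodeDecoder {
    // Codes copied from the TYPE_CODE_* constants quoted above (PyTorch Android IValue).
    private static final Map<Integer, String> NAMES = Map.ofEntries(
        Map.entry(2, "TENSOR"),
        Map.entry(3, "BOOL"),
        Map.entry(4, "LONG"),
        Map.entry(5, "DOUBLE"),
        Map.entry(6, "STRING"),
        Map.entry(7, "TUPLE"),
        Map.entry(8, "BOOL_LIST"),
        Map.entry(9, "LONG_LIST"),
        Map.entry(10, "DOUBLE_LIST"),
        Map.entry(11, "TENSOR_LIST"),
        Map.entry(12, "LIST"),
        Map.entry(13, "DICT_STRING_KEY"));

    // Returns the constant name for a type code, or "UNKNOWN" if it is not in the table.
    public static String describe(int code) {
        return NAMES.getOrDefault(code, "UNKNOWN");
    }

    // Builds a readable message for "the app expected one type, the model returned another".
    public static String explainMismatch(int expected, int actual) {
        return "expected " + describe(expected) + " (" + expected + "), got "
            + describe(actual) + " (" + actual + ")";
    }

    public static void main(String[] args) {
        // The situation in this issue: toDictStringKey() was called on a tensor output.
        System.out.println(explainMismatch(13, 2));
    }
}
```

Running `main` prints `expected DICT_STRING_KEY (13), got TENSOR (2)`, which is the mismatch described above.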

Try replacing lines 124-128 with:

  final Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor();
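Once the output is read with `toTensor()`, the app still has to turn the raw scores (for example from `outputTensor.getDataAsFloatArray()`) into a per-pixel class label. The sketch below assumes that your custom model, like DeepLabV3, returns scores laid out as `[1, numClasses, height, width]`; that layout, the class `SegmentationArgmax`, and the method `argmaxPerPixel` are assumptions for illustration, not code from the demo app.

```java
public class SegmentationArgmax {
    /**
     * Returns, for each pixel (index = row * width + col), the index of the
     * highest-scoring class. Assumes scores is a flattened
     * [1, numClasses, height, width] array, class planes stored one after another.
     */
    public static int[] argmaxPerPixel(float[] scores, int numClasses, int height, int width) {
        int plane = height * width;
        // Fail loudly if the class count in the app does not match the model's output.
        if (scores.length != numClasses * plane) {
            throw new IllegalArgumentException(
                "expected " + (numClasses * plane) + " scores, got " + scores.length);
        }
        int[] labels = new int[plane];
        for (int p = 0; p < plane; p++) {
            int best = 0;
            float bestScore = scores[p]; // score of class 0 at pixel p
            for (int c = 1; c < numClasses; c++) {
                float s = scores[c * plane + p];
                if (s > bestScore) {
                    bestScore = s;
                    best = c;
                }
            }
            labels[p] = best;
        }
        return labels;
    }

    public static void main(String[] args) {
        // 2 classes, 1x2 image: class-0 plane {0.1, 0.9}, class-1 plane {0.8, 0.2}.
        float[] scores = {0.1f, 0.9f, 0.8f, 0.2f};
        System.out.println(java.util.Arrays.toString(argmaxPerPixel(scores, 2, 1, 2)));
    }
}
```

The length check is a deliberate choice: if the app's hard-coded class count does not match a custom model, indexing the score array would otherwise fail with a less informative out-of-bounds error.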

ir1979 commented 2 years ago

Make sure you have loaded and saved the correct model. For image segmentation, I used the following code:

  model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True, progress=True)

ir1979 commented 2 years ago

Use the generated file deeplabv3_scripted.ptl instead of deeplabv3_scripted_optimized.ptl in your application, i.e., line 110 of MainActivity.java should be:

  mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "deeplabv3_scripted.ptl"));

Shahrullo commented 2 years ago

Try upgrading your Gradle version.