Closed iampawansingh closed 1 year ago
The issue was that I was training the model and serialising it in PyTorch 1.12.1. However, training the model in PyTorch 1.12.1 and serialising it in PyTorch 1.10.0 worked.
Although there is a workaround, this definitely is a bug in PyTorch >= 1.11. Has this issue been reported in the main PyTorch project? If so, please share a link to it.
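For anyone trying to narrow this down, a desktop round-trip check can show whether a given PyTorch version already produces a diverging `.ptl` file before Android is involved. The sketch below is a minimal example, not the tutorial's exact script. `lite_roundtrip_max_diff` and the small stand-in network are hypothetical and only for illustration (the real ResNet18 classifier would be passed in instead), and it assumes the desktop PyTorch build includes the lite interpreter.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.jit.mobile import _load_for_lite_interpreter
from torch.utils.mobile_optimizer import optimize_for_mobile


def lite_roundtrip_max_diff(model: nn.Module, example: torch.Tensor) -> float:
    """Export `model` for the lite interpreter, reload it on the desktop,
    and return the largest absolute difference between the eager output
    and the lite-interpreter output for `example`."""
    model.eval()  # put BatchNorm/Dropout in inference mode before export
    with torch.no_grad():
        expected = model(example)

    traced = torch.jit.trace(model, example)
    optimized = optimize_for_mobile(traced)

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.ptl")  # temporary file, not the app asset
        optimized._save_for_lite_interpreter(path)
        lite_module = _load_for_lite_interpreter(path)
        with torch.no_grad():
            actual = lite_module(example)

    return (expected - actual).abs().max().item()


# Hypothetical stand-in for the fine-tuned ResNet18 (two output classes).
torch.manual_seed(0)
stand_in = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)
diff = lite_roundtrip_max_diff(stand_in, torch.zeros(1, 3, 224, 224))
print(f"max abs difference after round trip: {diff:.3e}")
```

If the difference is already large here, the problem is probably in serialisation rather than in the Android runtime. A small difference on the desktop does not rule out a runtime mismatch on the device.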
I used the HelloWorld package and modified it slightly to create a classifier based on ResNet18. When this model is deployed on Android, its performance degrades badly: it predicts a single class for all images, whereas it predicts correctly on a Linux machine. I even tried a `torch.zeros()` input for model prediction and still got a different prediction. Detailed steps are given below:

1. Serialize the trained model, using the steps mentioned in the tutorial.
2. Load the model on Android using `LiteModuleLoader`:
   `module = LiteModuleLoader.load(assetFilePath(this, "resnet18_andorid_20221228_v1.ptl"));`
3. Create a tensor of all 0's.
4. Get the model prediction:
   `final Tensor outputTensor = module.forward(IValue.from(mInputTensor)).toTensor();`
5. Get the scores as an array:
   `final float[] scores = outputTensor.getDataAsFloatArray();`
   The above gives the scores `[-0.022879722, 0.014740771]`.
The steps in PyTorch are:

`model_ft(torch.unsqueeze(torch.zeros(3,224,224),0))`

The above gives the output `[ 0.1654, -0.5539]`.
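To make the mismatch concrete, the two score vectors reported in this issue for the all-zeros input can be compared directly. This is only an illustrative sketch: `compare_scores` and the tolerance `tol` are hypothetical and not part of PyTorch or the HelloWorld app, and the numbers are copied from the outputs quoted in this issue.

```python
def compare_scores(reference, candidate, tol=1e-4):
    """Compare two score vectors produced for the same input.

    `tol` is an assumed tolerance for float32 round-off, not an official value.
    """
    if len(reference) != len(candidate):
        raise ValueError("score vectors must have the same length")
    max_abs_diff = max(abs(r - c) for r, c in zip(reference, candidate))
    # Index of the largest score, i.e. the predicted class.
    reference_class = max(range(len(reference)), key=reference.__getitem__)
    candidate_class = max(range(len(candidate)), key=candidate.__getitem__)
    return {
        "max_abs_diff": max_abs_diff,
        "within_tol": max_abs_diff <= tol,
        "reference_class": reference_class,
        "candidate_class": candidate_class,
    }


pytorch_scores = [0.1654, -0.5539]             # desktop PyTorch output
android_scores = [-0.022879722, 0.014740771]   # Android lite interpreter output
report = compare_scores(pytorch_scores, android_scores)
print(report)
```

For these values the predicted class differs (index 0 on the desktop, index 1 on Android), and the scores are far apart, so this is not a rounding-level discrepancy.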
Environment details:

- torch version: `1.12.1.post200`
- torchvision version: `0.13.1`
build.gradle:

```
apply plugin: 'com.android.application'

android {
    compileSdkVersion 30
    buildToolsVersion "29.0.2"
    defaultConfig {
        applicationId "org.pytorch.helloworld"
        minSdkVersion 21
        targetSdkVersion 30
        versionCode 1
        versionName "1.0"
    }
    buildTypes {
        release {
            minifyEnabled false
        }
    }
}

dependencies {
    implementation 'androidx.appcompat:appcompat:1.1.0'
    implementation ('org.pytorch:pytorch_android_lite:1.12.2'){exclude group: 'org.pytorch:pytorch_android_lite:1.10.0'}
    implementation ('org.pytorch:pytorch_android_torchvision_lite:1.12.2') {exclude group: 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'}
}
```