Open md-rifatkhan opened 2 months ago
I'm trying to run inference with Real-ESRGAN, but I can't get the output correctly. I'm using com.microsoft.onnxruntime:onnxruntime-android:1.17.3
com.microsoft.onnxruntime:onnxruntime-android:1.17.3
Model Link: Google Drive
Inference Class:
public class ImageInference { private static final String TAG = "ImageInference"; // Load the image from assets public static Bitmap loadImageFromAssets(Context context, String fileName) throws IOException { Log.d(TAG, "Loading image from assets: " + fileName); InputStream is = context.getAssets().open(fileName); Bitmap image = BitmapFactory.decodeStream(is); if (image != null) { Log.d(TAG, "Image loaded successfully: " + fileName); } else { Log.d(TAG, "Failed to load image: " + fileName); } return image; } // Convert Bitmap to FloatBuffer for ONNX Runtime public static FloatBuffer bitmapToFloatBuffer(Bitmap bitmap, float mean, float std) { int width = bitmap.getWidth(); int height = bitmap.getHeight(); Log.d(TAG, "Preparing to convert bitmap to FloatBuffer. Width: " + width + ", Height: " + height); int[] pixels = new int[width * height]; bitmap.getPixels(pixels, 0, width, 0, 0, width, height); FloatBuffer buffer = FloatBuffer.allocate(width * height * 3); for (final int val : pixels) { buffer.put(((val >> 16) & 0xFF) / 255.f - mean / std); // RED buffer.put(((val >> 8) & 0xFF) / 255.f - mean / std); // GREEN buffer.put((val & 0xFF) / 255.f - mean / std); // BLUE } buffer.flip(); // Prepare buffer for reading Log.d(TAG, "Bitmap successfully converted to FloatBuffer."); return buffer; } public static Bitmap tensorToBitmap(float[][][][] tensor) { // Assume tensor dimensions are [1][3][height][width] int height = tensor[0][0].length; int width = tensor[0][0][0].length; Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888); int[] pixels = new int[width * height]; for (int y = 0; y < height; y++) { for (int x = 0; x < width; x++) { int r = (int) (tensor[0][0][y][x] * 255); int g = (int) (tensor[0][1][y][x] * 255); int b = (int) (tensor[0][2][y][x] * 255); pixels[y * width + x] = 0xFF000000 | (r << 16) | (g << 8) | b; } } bitmap.setPixels(pixels, 0, width, 0, 0, width, height); return bitmap; } public static float[][][][] runInference(Context 
context, String modelPath, Bitmap image) throws OrtException, IOException { InputStream is = context.getAssets().open(modelPath); ByteArrayOutputStream bos = new ByteArrayOutputStream(); byte[] buf = new byte[1024]; for (int readNum; (readNum = is.read(buf)) != -1;) { bos.write(buf, 0, readNum); } byte[] modelBytes = bos.toByteArray(); Log.d(TAG, "Starting inference with model: " + modelPath); OrtEnvironment env = OrtEnvironment.getEnvironment(); InputStream modelInputStream = context.getAssets().open(modelPath); OrtSession session = env.createSession(modelBytes, new OrtSession.SessionOptions()); Log.d(TAG, "Model and Session created successfully." ); try { int width = image.getWidth(); int height = image.getHeight(); FloatBuffer inputBuffer = bitmapToFloatBuffer(image, 0f, 1f); OnnxTensor tensor = OnnxTensor.createTensor(env, inputBuffer, new long[]{1, 3, height, width}); OrtSession.Result results = session.run(Collections.singletonMap("input", tensor)); float[][][][] output = (float[][][][]) results.get(0).getValue(); tensor.close(); Log.d(TAG, "Inference completed successfully."); return output; } finally { session.close(); modelInputStream.close(); env.close(); Log.d(TAG, "Cleaned up ONNX resources."); } } }
Main Activity
public class MainActivity extends Activity { private ImageView orginalImageView; private ImageView outputImageView; private TextView textViewResult; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); orginalImageView = findViewById(R.id.orginalImage); outputImageView = findViewById(R.id.outputImage); textViewResult = findViewById(R.id.textViewResult); try { Bitmap image = ImageInference.loadImageFromAssets(this, "LR.png"); orginalImageView.setImageBitmap(image); float[][][][] output = ImageInference.runInference(this, "realesr-general-x4v3-fp32.onnx", image); if (output != null && output.length > 0 && output[0].length > 0 && output[0][0].length > 0) { int height = output[0][0].length; int width = output[0][0][0].length; Bitmap outputImage = ImageInference.tensorToBitmap(output); outputImageView.setImageBitmap(outputImage); // Display the output image textViewResult.setText("Inference complete with output shape: " + "Height " + height + " Width " + width); } else { textViewResult.setText("Inference complete but no valid output!"); } } catch (Exception e) { textViewResult.setText("Inference failed: " + e.getMessage()); e.printStackTrace(); } }
I'm trying to run inference with Real-ESRGAN, but I can't get the output correctly. I'm using
com.microsoft.onnxruntime:onnxruntime-android:1.17.3
Model Link: Google Drive
Inference Class:
Main Activity