microsoft / onnxruntime-inference-examples

Examples for using ONNX Runtime for machine learning inferencing.

Output image is weird while trying to run inference on ESRGAN, someone help me please. #423

Open md-rifatkhan opened 2 months ago

md-rifatkhan commented 2 months ago

I'm trying to run inference with Real-ESRGAN, but I can't get the output correctly. I'm using com.microsoft.onnxruntime:onnxruntime-android:1.17.3.

[Image: WhatsApp Image 2024-05-11 at 2 33 10 PM]

Model Link: Google Drive
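
One way to start debugging this is to confirm which input/output names and shapes the exported model actually reports. Below is a minimal sketch using the ONNX Runtime Java API; the class and method names (ModelInspector, logModelIo) are made up for illustration, and the assets-based loading mirrors the inference class further down.

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;

import android.content.Context;
import android.util.Log;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Map;

public class ModelInspector {

    private static final String TAG = "ModelInspector";

    // Logs the input and output tensor names and shapes of an ONNX model stored in assets.
    public static void logModelIo(Context context, String assetName) throws IOException, OrtException {
        byte[] modelBytes;
        try (InputStream is = context.getAssets().open(assetName)) {
            modelBytes = readAllBytes(is);
        }
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession(modelBytes, new OrtSession.SessionOptions())) {
            for (Map.Entry<String, NodeInfo> e : session.getInputInfo().entrySet()) {
                TensorInfo info = (TensorInfo) e.getValue().getInfo();
                Log.d(TAG, "Input  " + e.getKey() + " shape=" + Arrays.toString(info.getShape()));
            }
            for (Map.Entry<String, NodeInfo> e : session.getOutputInfo().entrySet()) {
                TensorInfo info = (TensorInfo) e.getValue().getInfo();
                Log.d(TAG, "Output " + e.getKey() + " shape=" + Arrays.toString(info.getShape()));
            }
        }
    }

    private static byte[] readAllBytes(InputStream is) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte[] buf = new byte[4096];
        for (int n; (n = is.read(buf)) != -1; ) {
            bos.write(buf, 0, n);
        }
        return bos.toByteArray();
    }
}

Dynamic dimensions show up as -1 in the reported shape; the fixed dimensions and the input name are what matter for the code below.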

Inference Class:

public class ImageInference {

    private static final String TAG = "ImageInference";

    // Load the image from assets
    public static Bitmap loadImageFromAssets(Context context, String fileName) throws IOException {
        Log.d(TAG, "Loading image from assets: " + fileName);
        InputStream is = context.getAssets().open(fileName);
        Bitmap image = BitmapFactory.decodeStream(is);
        if (image != null) {
            Log.d(TAG, "Image loaded successfully: " + fileName);
        } else {
            Log.d(TAG, "Failed to load image: " + fileName);
        }
        return image;
    }

    // Convert Bitmap to FloatBuffer for ONNX Runtime
    public static FloatBuffer bitmapToFloatBuffer(Bitmap bitmap, float mean, float std) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        Log.d(TAG, "Preparing to convert bitmap to FloatBuffer. Width: " + width + ", Height: " + height);

        int[] pixels = new int[width * height];
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height);
        FloatBuffer buffer = FloatBuffer.allocate(width * height * 3);

        // Note: this writes R, G, B per pixel in turn, i.e. pixel-interleaved (HWC) order.
        for (final int val : pixels) {
            buffer.put(((val >> 16) & 0xFF) / 255.f - mean / std); // RED
            buffer.put(((val >> 8) & 0xFF) / 255.f - mean / std);  // GREEN
            buffer.put((val & 0xFF) / 255.f - mean / std);         // BLUE
        }
        buffer.flip(); // Prepare buffer for reading
        Log.d(TAG, "Bitmap successfully converted to FloatBuffer.");
        return buffer;
    }

    public static Bitmap tensorToBitmap(float[][][][] tensor) {
        // Assume tensor dimensions are [1][3][height][width]
        int height = tensor[0][0].length;
        int width = tensor[0][0][0].length;

        Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
        int[] pixels = new int[width * height];

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // Note: no clamping here; values outside [0, 1] corrupt the packed ARGB value.
                int r = (int) (tensor[0][0][y][x] * 255);
                int g = (int) (tensor[0][1][y][x] * 255);
                int b = (int) (tensor[0][2][y][x] * 255);
                pixels[y * width + x] = 0xFF000000 | (r << 16) | (g << 8) | b;
            }
        }

        bitmap.setPixels(pixels, 0, width, 0, 0, width, height);
        return bitmap;
    }

    public static float[][][][] runInference(Context context, String modelPath, Bitmap image) throws OrtException, IOException {

        InputStream is = context.getAssets().open(modelPath);
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte[] buf = new byte[1024];
        for (int readNum; (readNum = is.read(buf)) != -1;) {
            bos.write(buf, 0, readNum);
        }
        byte[] modelBytes = bos.toByteArray();
        is.close();

        Log.d(TAG, "Starting inference with model: " + modelPath);

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession session = env.createSession(modelBytes, new OrtSession.SessionOptions());
        Log.d(TAG, "Model and Session created successfully." );
        try {
            int width = image.getWidth();
            int height = image.getHeight();
            FloatBuffer inputBuffer = bitmapToFloatBuffer(image, 0f, 1f);
            OnnxTensor tensor = OnnxTensor.createTensor(env, inputBuffer, new long[]{1, 3, height, width});
            OrtSession.Result results = session.run(Collections.singletonMap("input", tensor));
            float[][][][] output = (float[][][][]) results.get(0).getValue();
            results.close();
            tensor.close();
            Log.d(TAG, "Inference completed successfully.");
            return output;
        } finally {
            session.close();
            env.close();
            Log.d(TAG, "Cleaned up ONNX resources.");
        }
    }
}
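
One thing worth noting about the class above: the tensor is created with shape {1, 3, height, width} (channel-first, NCHW), but bitmapToFloatBuffer fills the buffer pixel-interleaved (HWC). If the export really expects NCHW input, as that shape suggests, the float data needs to be channel-planar, and the model output should be clamped before packing into ARGB. A minimal sketch of both, assuming the model takes and returns RGB in the [0, 1] range (the helper names bitmapToChwFloatBuffer and tensorToBitmapClamped are made up; they are drop-in alternatives for the two methods in the class above):

    // Sketch only: assumes NCHW layout and a [0, 1] value range.
    public static FloatBuffer bitmapToChwFloatBuffer(Bitmap bitmap) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        int[] pixels = new int[width * height];
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height);

        // Channel-planar layout: all R values, then all G values, then all B values.
        int area = width * height;
        FloatBuffer buffer = FloatBuffer.allocate(3 * area);
        for (int i = 0; i < area; i++) {
            int val = pixels[i];
            buffer.put(i,            ((val >> 16) & 0xFF) / 255.f); // R plane
            buffer.put(i + area,     ((val >> 8) & 0xFF) / 255.f);  // G plane
            buffer.put(i + 2 * area, (val & 0xFF) / 255.f);         // B plane
        }
        // Absolute puts leave position at 0, so no flip() is needed before createTensor.
        return buffer;
    }

    public static Bitmap tensorToBitmapClamped(float[][][][] tensor) {
        int height = tensor[0][0].length;
        int width = tensor[0][0][0].length;
        int[] pixels = new int[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int r = clamp255(tensor[0][0][y][x]);
                int g = clamp255(tensor[0][1][y][x]);
                int b = clamp255(tensor[0][2][y][x]);
                pixels[y * width + x] = 0xFF000000 | (r << 16) | (g << 8) | b;
            }
        }
        Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
        bitmap.setPixels(pixels, 0, width, 0, 0, width, height);
        return bitmap;
    }

    private static int clamp255(float v) {
        // Clamp to [0, 255] so out-of-range model outputs do not wrap around.
        int i = Math.round(v * 255f);
        return Math.max(0, Math.min(255, i));
    }

Whether this matches realesr-general-x4v3-fp32.onnx exactly (value range, channel order) would need to be confirmed against the export script.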

Main Activity

public class MainActivity extends Activity {

    private ImageView orginalImageView;
    private ImageView outputImageView;
    private TextView textViewResult;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        orginalImageView = findViewById(R.id.orginalImage);
        outputImageView = findViewById(R.id.outputImage);
        textViewResult = findViewById(R.id.textViewResult);

        try {
            Bitmap image = ImageInference.loadImageFromAssets(this, "LR.png");
            orginalImageView.setImageBitmap(image);
            float[][][][] output = ImageInference.runInference(this, "realesr-general-x4v3-fp32.onnx", image);
            if (output != null && output.length > 0 && output[0].length > 0 && output[0][0].length > 0) {
                int height = output[0][0].length;
                int width = output[0][0][0].length;

                Bitmap outputImage = ImageInference.tensorToBitmap(output);
                outputImageView.setImageBitmap(outputImage);  // Display the output image
                textViewResult.setText("Inference complete with output shape: " + "Height " + height + " Width "  + width);
            } else {
                textViewResult.setText("Inference complete but no valid output!");
            }
        } catch (Exception e) {
            textViewResult.setText("Inference failed: " + e.getMessage());
            e.printStackTrace();
        }
    }
}
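
As a side note, runInference does blocking asset I/O and model execution directly in onCreate, i.e. on the UI thread. A minimal sketch of moving the same try/catch body onto a background thread with a single-thread executor (ExecutorService and Executors are from java.util.concurrent; the field names match the activity above):

        ExecutorService executor = Executors.newSingleThreadExecutor();
        executor.execute(() -> {
            try {
                Bitmap image = ImageInference.loadImageFromAssets(this, "LR.png");
                float[][][][] output = ImageInference.runInference(this, "realesr-general-x4v3-fp32.onnx", image);
                Bitmap outputImage = ImageInference.tensorToBitmap(output);
                // Post results back to the UI thread before touching views.
                runOnUiThread(() -> {
                    orginalImageView.setImageBitmap(image);
                    outputImageView.setImageBitmap(outputImage);
                    textViewResult.setText("Inference complete");
                });
            } catch (Exception e) {
                runOnUiThread(() -> textViewResult.setText("Inference failed: " + e.getMessage()));
            }
        });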
md-rifatkhan commented 2 months ago

[Image attached]