gtbluesky / onnxruntime_flutter

A Flutter plugin for OnnxRuntime that provides an easy, flexible, and fast Dart API for integrating ONNX models into Flutter apps across mobile and desktop platforms.
MIT License
68 stars 12 forks source link

Need help to inference image for upscaling #19

Closed md-rifatkhan closed 4 months ago

md-rifatkhan commented 4 months ago

image

I tried to upscale a simple red image, but the output is weird:

    /// Loads the bundled ONNX model asset and creates the inference [session]
    /// from its bytes.
    ///
    /// Also prints the execution providers reported by `ortEnv`.
    Future<void> loadModel() async {
    print(ortEnv.availableProviders());
    final byteData =
        await rootBundle.load('assets/realesr-general-x4v3-fp32.onnx');

    // A ByteData can be a view onto part of a larger buffer, so pass the
    // view's own offset and length instead of exposing the whole buffer.
    final modelBytes = byteData.buffer
        .asUint8List(byteData.offsetInBytes, byteData.lengthInBytes);

    final options = OrtSessionOptions();
    session = OrtSession.fromBuffer(modelBytes, options);
  }

  /// Runs the model on a synthetic 256x256 solid-red image and publishes the
  /// result through [outputImage].
  ///
  /// The input is normalised to `[0, 1]` and written channel-planar so that it
  /// matches the declared tensor shape `[1, 3, height, width]`. The output is
  /// read back with the same layout, scaled to `[0, 255]` and packed as RGBA.
  ///
  /// Does nothing further if the first output is not a 4-D nested list of
  /// doubles. Throws a [StateError] if the pixel buffer has fewer than three
  /// bytes per pixel or if the session returns no outputs.
  Future<void> runInference() async {
    final image = img.Image(256, 256);

    // Sample image with red color.
    final redColor = img.getColor(255, 0, 0);
    for (var y = 0; y < image.height; y++) {
      for (var x = 0; x < image.width; x++) {
        image.setPixel(x, y, redColor);
      }
    }
    orginalImage = image;

    final inHeight = image.height;
    final inWidth = image.width;
    final planeSize = inHeight * inWidth;

    // getBytes() yields interleaved pixels. Derive the per-pixel stride from
    // the buffer length instead of assuming 3, because the default format may
    // carry an alpha byte (RGBA) — assuming 3 reads misaligned channels.
    final List<int> pixelBytes = image.getBytes();
    final bytesPerPixel = pixelBytes.length ~/ planeSize;
    if (bytesPerPixel < 3) {
      throw StateError(
          'Expected at least 3 bytes per pixel, got $bytesPerPixel');
    }

    // The tensor is declared as [1, 3, H, W], so each colour channel must
    // occupy its own contiguous plane (all R, then all G, then all B).
    final inputTensor = Float32List(3 * planeSize);
    for (var pixel = 0; pixel < planeSize; pixel++) {
      final src = pixel * bytesPerPixel;
      for (var c = 0; c < 3; c++) {
        inputTensor[c * planeSize + pixel] = pixelBytes[src + c] / 255.0;
      }
    }

    // Maps a normalised channel value to a byte, rounding to nearest.
    int toByte(double value) => (value * 255.0).clamp(0.0, 255.0).round();

    final options = OrtRunOptions();
    final createInput = OrtValueTensor.createTensorWithDataList(
        inputTensor, [1, 3, inHeight, inWidth]);

    img.Image? upscaled;
    try {
      final outputs = await session.runAsync(options, {'input': createInput});
      if (outputs == null) {
        throw StateError('Inference returned no outputs');
      }
      try {
        final outputTensor = outputs[0]?.value;
        if (outputTensor is List<List<List<List<double>>>>) {
          final planes = outputTensor[0];
          final height = planes[0].length;
          final width = planes[0][0].length;

          // Copy into a Dart-owned RGBA buffer before the native outputs are
          // released below.
          final outputImageData = Uint8List(width * height * 4);
          var o = 0;
          for (var h = 0; h < height; h++) {
            final redRow = planes[0][h];
            final greenRow = planes[1][h];
            final blueRow = planes[2][h];
            for (var w = 0; w < width; w++) {
              outputImageData[o++] = toByte(redRow[w]);
              outputImageData[o++] = toByte(greenRow[w]);
              outputImageData[o++] = toByte(blueRow[w]);
              outputImageData[o++] = 255;
            }
          }
          upscaled = img.Image.fromBytes(width, height, outputImageData);
        }
      } finally {
        for (final output in outputs) {
          output?.release();
        }
      }
    } finally {
      // Native handles are not garbage collected; free them on every path.
      createInput.release();
      options.release();
    }

    final result = upscaled;
    // The widget may have been disposed while inference was running.
    if (result == null || !mounted) return;

    setState(() {
      outputImage = result;
    });
    ScaffoldMessenger.of(context).showSnackBar(
      SnackBar(
        content: Text(
            'Output Image Dimensions: ${result.width} x ${result.height}'),
        duration: const Duration(seconds: 2),
      ),
    );
  }