gtbluesky / onnxruntime_flutter

A flutter plugin for OnnxRuntime provides an easy, flexible, and fast Dart API to integrate Onnx models in flutter apps across mobile and desktop platforms.
MIT License
54 stars 13 forks source link

Inferenced output image is weired result, can someone help me? #14

Closed md-rifatkhan closed 1 month ago

md-rifatkhan commented 2 months ago

I'm trying in inference upscaling model but its not working. Someone help me.

Orginal Image: https://github.com/gtbluesky/onnxruntime_flutter/assets/102645154/62188e6d-457f-4783-9b83-7d409cc16eb8

Output Image: https://github.com/gtbluesky/onnxruntime_flutter/assets/102645154/4de92532-d225-4f08-b6d6-b99ad9854135

Some Details:

Model input shape: ['batch_size', 3, 'width', 'height']
Model output shape: ['batch_size', 3, 'width', 'height']

flutter: Is normalized: true
flutter: Image normalized successfully.
flutter: Input tensor created successfully.
flutter: Width: 1080, Height: 1396, Channel: 3
Future<void> inference() async {
    if (selectedImage == null) {
      debugPrint('No image selected');
      return;
    }

    if (selectedImage != null) {
      Float32List? floatData;
      try {
        final normalizedImage = selectedImage!.convert(format: img.Format.float32, numChannels: 3,);
        final pixelData = normalizedImage.buffer.asFloat32List();
        floatData = pixelData;
        debugPrint("Is normalized: ${isNormalized(pixelData)}");
      } catch (e) {
        debugPrint("Error during normalization: $e");
      }

      final shape = [1, 3, selectedImage!.width, selectedImage!.height];

      debugPrint('Image normalized successfully.');

      final inputOrt =
          OrtValueTensor.createTensorWithDataList(floatData!, shape);

      final inputs = {'input': inputOrt};

      debugPrint('Input tensor created successfully.');

      final runOptions = OrtRunOptions();
      final outputs = await ortSession.runAsync(runOptions, inputs);

      inputOrt.release();
      runOptions.release();

      List c = outputs?[0]?.value as List;
      if (c is List<List<List<List<double>>>>) {
        img.Image generatedImage = generateImageFromOutput(c);
        if(!context.mounted) return;
        showDialog(
          context: context,
          builder: (BuildContext context) {
            return Dialog(
              child: SizedBox(
                width: generatedImage.width.toDouble(),
                height: generatedImage.height.toDouble(),
                child: Image.memory(
                  Uint8List.fromList(img.encodePng(generatedImage)),
                  fit: BoxFit.contain,
                ),
              ),
            );
          },
        );
      } else {
        debugPrint("Output is of unknown type");
      }
      outputs?.forEach((element) {
        element?.release();
      });
    }
  }

  img.Image generateImageFromOutput(
      List<List<List<List<double>>>> outputValue) {
    int width = outputValue[0][0].length;
    int height = outputValue[0][0][0].length;
    int channel = outputValue[0].length;

    print("Width: $width, Height: $height, Channel: $channel");

    // Create the image
    img.Image generatedImage = img.Image(width: width, height: height);

    // Set pixel values
    for (int y = 0; y < height; y++) {
      for (int x = 0; x < width; x++) {
        // Extract RGB values from the output tensor data
        int r = (outputValue[0][0][x][y] * 255).toInt().clamp(0, 255);
        int g = (outputValue[0][1][x][y] * 255).toInt().clamp(0, 255);
        int b = (outputValue[0][2][x][y] * 255).toInt().clamp(0, 255);
        // Set pixel value in the generated image
        generatedImage.setPixelRgb(x, y, r, g, b);
      }
    }
    return generatedImage;
  }

  bool isNormalized(Float32List data) {
    for (var pixelValue in data) {
      if (pixelValue < 0 || pixelValue > 1) {
        return false;
      }
    }
    return true;
  }
md-rifatkhan commented 2 months ago

image