abdelaziz-mahdy / pytorch_lite

Flutter package to help run PyTorch Lite models for classification, YOLOv5, and YOLOv8.
MIT License

Error using Isolates (Multithreading) #55

Open tnghieu opened 11 months ago

tnghieu commented 11 months ago

Is it possible to use pytorch_lite in isolates for running inference? I'm trying to split up the work on separate threads to speed up the process.

[ERROR:flutter/runtime/dart_isolate.cc(1097)] Unhandled exception: Bad state: The BackgroundIsolateBinaryMessenger.instance value is invalid until BackgroundIsolateBinaryMessenger.ensureInitialized is executed.

#0      BackgroundIsolateBinaryMessenger.instance (package:flutter/src/services/_background_isolate_binary_messenger_io.dart:27:7)
#1      _findBinaryMessenger (package:flutter/src/services/platform_channel.dart:145:42)
#2      BasicMessageChannel.binaryMessenger (package:flutter/src/services/platform_channel.dart:192:56)
#3      BasicMessageChannel.send (package:flutter/src/services/platform_channel.dart:206:38)
#4      ModelApi.getImagePredictionListObjectDetection (package:pytorch_lite/pigeon.dart:319:52)
#5      ModelObjectDetection.getImagePredictionList (package:pytorch_lite/pytorch_lite.dart:525:30)
#6      ModelObjectDetection.getImagePrediction (package:pytorch_lite/pytorch_lite.dart:455:52)
#7      _InferencePageState.processImageBatch (package:sharkeye/inference_page.dart:120:38)
#8      _delayEntrypointInvocation.<anonymous closure> (dart:isolate-patch/isolate_patch.dart:300:17)
#9      _RawReceivePort._handleMessage (dart:isolate-patch/isolate_patch.dart:184:12)

ChatGPT's response to this error: 'The PytorchLite plugin is attempting to use platform channels (binary messenger) inside an isolate, which is not supported by default in Flutter.'
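For context, the fix Flutter documents for this error is to hand the spawned isolate a RootIsolateToken and call BackgroundIsolateBinaryMessenger.ensureInitialized before any plugin call. A minimal sketch (the entry-point name and message shape are placeholders, not part of pytorch_lite):

    import 'dart:isolate';

    import 'package:flutter/services.dart';

    Future<void> spawnWorker() async {
      // The token can only be obtained on the root (UI) isolate.
      final RootIsolateToken token = RootIsolateToken.instance!;
      await Isolate.spawn(_entryPoint, token);
    }

    void _entryPoint(RootIsolateToken token) {
      // Without this call, any plugin that talks over platform channels throws
      // the "BackgroundIsolateBinaryMessenger.instance value is invalid" error.
      BackgroundIsolateBinaryMessenger.ensureInitialized(token);
      // Platform-channel / plugin calls are now allowed in this isolate.
    }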

abdelaziz-mahdy commented 11 months ago

Right now the code runs inference in the background by default.

Searching around, I found https://pub.dev/packages/flutter_isolate; also, let me know if it's faster.
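For reference, a minimal sketch of what flutter_isolate usage looks like, assuming its FlutterIsolate.spawn mirrors Isolate.spawn as the package README describes (check the package docs for the exact API); its isolates are backed by a Flutter engine, so plugin calls work without the messenger setup shown above:

    import 'package:flutter_isolate/flutter_isolate.dart';

    Future<void> spawnPluginFriendlyIsolate() async {
      // Entry point and a single message, same shape as Isolate.spawn.
      await FlutterIsolate.spawn(_work, 'start');
    }

    @pragma('vm:entry-point') // keeps the entry point reachable in release builds
    void _work(String message) {
      // Plugins that use platform channels can be called from here.
    }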

tnghieu commented 11 months ago

I was able to implement this with native Dart isolates.

It is indeed faster, in the sense that the work is split across N isolates.

However, because of the number of images I'm running through inference, I'm hitting memory issues:

* thread #7, queue = 'com.apple.root.default-qos', stop reason = EXC_RESOURCE (RESOURCE_TYPE_MEMORY: high watermark memory limit exceeded) (limit=3072 MB)
    frame #0: 0x00000001055bb50c Runner`void c10::function_ref<void (char**, long long const*, long long, long long)>::callback_fn<at::native::DEFAULT::VectorizedLoop2d<at::native::DEFAULT::direct_copy_kernel(at::TensorIteratorBase&)::$_1::operator()() const::'lambda12'()::operator()() const::'lambda'(float), at::native::DEFAULT::direct_copy_kernel(at::TensorIteratorBase&)::$_1::operator()() const::'lambda12'()::operator()() const::'lambda'(at::vec::DEFAULT::Vectorized<float>)> >(long, char**, long long const*, long long, long long) + 496
Runner`c10::function_ref<void (char**, long long const*, long long, long long)>::callback_fn<at::native::DEFAULT::VectorizedLoop2d<at::native::DEFAULT::direct_copy_kernel(at::TensorIteratorBase&)::$_1::operator()() const::'lambda12'()::operator()() const::'lambda'(float), at::native::DEFAULT::direct_copy_kernel(at::TensorIteratorBase&)::$_1::operator()() const::'lambda12'()::operator()() const::'lambda'(at::vec::DEFAULT::Vectorized<float>)> >:

Looking through old issues here and in the flutter_tflite issues, it appears that inference is currently run on the CPU instead of the GPU, resulting in slow inference of around 200-400 ms per image.

  1. Are you still looking into how to implement a GPU solution?
  2. Do you have any suggestions for running inference on a large number of images without running into memory issues? I have tried using a queue to limit the number of images each isolate has to process at a time (a sketch of this idea follows below).
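A minimal sketch of the queue idea from point 2: a fixed pool of N workers pulls frames from a single queue, so at most N images are being decoded at any moment. processFrame is a placeholder for the decode-plus-predict step.

    import 'dart:collection';

    Future<void> runWithWorkerPool(
      List<int> frames,
      Future<void> Function(int frame) processFrame, {
      int workers = 3,
    }) async {
      final queue = Queue<int>.of(frames);

      Future<void> worker() async {
        while (queue.isNotEmpty) {
          final frame = queue.removeFirst();
          await processFrame(frame); // only one frame held per worker
        }
      }

      // All workers run concurrently; memory use is bounded by `workers`.
      await Future.wait(List.generate(workers, (_) => worker()));
    }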
abdelaziz-mahdy commented 11 months ago

I would like to see how you implemented the logic for the isolates; maybe I can use it in the package to improve speed. If you want to open a PR, that would be much appreciated.

Also, GPU is not supported by pytorch_mobile, which is why I can't use it.

For memory, once I see the isolate implementation I may be able to reduce the memory load, but not by much.

tnghieu commented 11 months ago

Still troubleshooting with isolates. I notice that pytorch_lite uses the Computer library to spawn 2 workers through ImageUtilsIsolate (the default).

So if I have the main UI thread (1), 2 of my own isolates (2), and therefore 4 Computer worker threads (4), I hit a total of 7 threads, which I believe is more than the number of cores on an iPhone.
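For sizing the pool, Dart exposes the device's core count, so the number of extra isolates can be derived instead of hard-coded. A minimal sketch; reserving two cores for the UI isolate and the package's own workers is an assumption, not a measured figure:

    import 'dart:io';
    import 'dart:math';

    // e.g. a 6-core iPhone -> 4 worker isolates
    int isolateCount() => max(1, Platform.numberOfProcessors - 2);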

abdelaziz-mahdy commented 11 months ago

From my understanding of processors, if one task is not running, another task takes over the core.

So it should not be a problem; even if you use 100 isolates, it won't cause any problems, but you won't see much of an increase in performance beyond a certain point.

abdelaziz-mahdy commented 11 months ago

I really don't understand how you were able to run the package inside threads!

It already runs as an asynchronous task so that it does not block the UI thread, so it should not be any faster when used in threads; this is why I wanted an example to test on.
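A minimal sketch of the baseline being described: the prediction call is already async, so simply awaiting it in a loop keeps the UI responsive without any isolate plumbing. The images map and its contents are placeholders.

    import 'dart:typed_data';

    import 'package:pytorch_lite/pytorch_lite.dart';

    Future<Map<int, List<ResultObjectDetection>>> runSequentially(
        ModelObjectDetection model, Map<int, Uint8List> images) async {
      final results = <int, List<ResultObjectDetection>>{};
      for (final entry in images.entries) {
        // Awaited one at a time; the heavy work happens inside the plugin,
        // off the UI thread.
        results[entry.key] = await model.getImagePrediction(entry.value);
      }
      return results;
    }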

tnghieu commented 11 months ago

Can you expand on that?

Could I not spawn an isolate, load the ModelObjectDetection class into it, and then have separate instances of the model to run predictions with?

  Future<void> initializeIsolates() async {
    print('work items len: ${workItems.length}');

    List<Future> isolateFutures =
        []; // This will hold futures of all the isolates.

    for (List<int> batch in workItems) {
      final ReceivePort receivePort = ReceivePort();
      Completer<void> completer =
          Completer<void>(); // Completer for this batch.

      RootIsolateToken rootIsolateToken = RootIsolateToken.instance!;

      Isolate isolate = await Isolate.spawn(isolateWork, {
        'sendPort': receivePort.sendPort,
        'batch': batch,
        'videoPath': videoPath,
        'rootIsolateToken': rootIsolateToken,
        'model': objectDetectionModel,
      });

      // Now we listen to the messages from the isolate.
      receivePort.listen(
        (message) {
          if (message is Map<int, List<ResultObjectDetection>>) {
            detectionsMap.addAll(message);
            setState(() {});
            if (message.length < 10) {
              receivePort.close();
              isolate.kill(priority: Isolate.immediate);
              completer.complete(); // Complete the completer when done.
            }
          } else {
            print('RECEIVE PORT WRONG DATA TYPE??');
            completer.completeError('Wrong data type received');
          }
        },
        onDone: () {
          if (!completer.isCompleted) {
            completer.complete(); // Complete if the receive port is closed.
          }
        },
        onError: (error) {
          completer.completeError(error); // Complete with error on any error.
        },
      );

      isolateFutures.add(
          completer.future); // Add this batch's completer's future to the list.
    }

    // Wait for all isolates to complete.
    await Future.wait(isolateFutures);
    print('All isolates completed their work.');
  }

  static Future<void> isolateWork(Map<String, dynamic> args) async {
    SendPort sendPort = args['sendPort'];
    List<int> workItems = args['batch'];
    String videoPath = args['videoPath'];
    ModelObjectDetection model = args['model'];
    RootIsolateToken rootIsolateToken = args['rootIsolateToken'];

    BackgroundIsolateBinaryMessenger.ensureInitialized(rootIsolateToken);

    for (int i = 0; i < workItems.length; i += 10) {
      int start = i;
      int end = i + 10 > workItems.length ? workItems.length : i + 10;
      List<int> batch = workItems.sublist(start, end);
      Map<int, List<ResultObjectDetection>> result = await processImages(
        batch,
        videoPath,
        model,
      );
      sendPort.send(result);
    }
  }

  static Future<Map<int, List<ResultObjectDetection>>> processImages(
    List<int> frames,
    String videoPath,
    ModelObjectDetection model,
  ) async {
    Map<int, List<ResultObjectDetection>> results = {};
    for (int frame in frames) {
      Uint8List? image = await VideoThumbnail.thumbnailData(
        video: videoPath,
        imageFormat: ImageFormat.JPEG,
        timeMs: frame,
      );
      if (image != null) {
        List<ResultObjectDetection> detections =
            await model.getImagePrediction(image);

        results[frame] = detections;
      }
      image = null;
    }

    return results;
  }
abdelaziz-mahdy commented 11 months ago

Is using this better than just using direct inference without isolates?

I think it should be the same; I didn't test it, I'm just reasoning about the logic.

tnghieu commented 11 months ago

My intention was to split the work of inference between the threads.

For a given set of 1000 images, give isolate 1: 333, isolate 2: 333, and isolate 3: 334 to run in parallel. My expectation is that it would take 334 images' worth of inference time to process. Is this logic incorrect?
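For concreteness, a small helper that produces the 333 / 333 / 334 split described above (splitEvenly is a hypothetical name, not part of the package):

    /// Splits [items] into [parts] nearly-equal chunks,
    /// e.g. 1000 items into 333 / 333 / 334.
    List<List<T>> splitEvenly<T>(List<T> items, int parts) {
      final chunks = <List<T>>[];
      final base = items.length ~/ parts;     // 333 for 1000 / 3
      final remainder = items.length % parts; // leftover items to distribute
      var start = 0;
      for (var i = 0; i < parts; i++) {
        final size = base + (i < parts - remainder ? 0 : 1);
        chunks.add(items.sublist(start, start + size));
        start += size;
      }
      return chunks;
    }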

abdelaziz-mahdy commented 11 months ago

I don't know.

But let's think of it another way: if PyTorch uses all cores to run inference, then this logic is not correct, since each inference will already be using all cores, so it may take the time of 900 images instead of 1000.

If PyTorch uses 1 core, then yes, this logic is correct.

The package's default async tasks should launch isolates by default, so I don't think using your own isolates to distribute the tasks is gaining you any performance.

I think running the 1000 images as batches of 3, just awaiting the package's calls, will give the same performance or better.
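A minimal sketch of that suggestion: keep only a small batch (here 3 images) in flight at once, awaiting each batch before starting the next. loadFrame is a placeholder for whatever produces the image bytes (e.g. a video-thumbnail call).

    import 'dart:math';
    import 'dart:typed_data';

    import 'package:pytorch_lite/pytorch_lite.dart';

    Future<Map<int, List<ResultObjectDetection>>> runInBatches(
      ModelObjectDetection model,
      List<int> frames,
      Future<Uint8List> Function(int frame) loadFrame, {
      int batchSize = 3,
    }) async {
      final results = <int, List<ResultObjectDetection>>{};
      for (var i = 0; i < frames.length; i += batchSize) {
        final batch = frames.sublist(i, min(i + batchSize, frames.length));
        // Load and predict the batch concurrently, then wait before moving on,
        // so at most `batchSize` decoded images exist at a time.
        await Future.wait(batch.map((frame) async {
          final bytes = await loadFrame(frame);
          results[frame] = await model.getImagePrediction(bytes);
        }));
      }
      return results;
    }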

tnghieu commented 11 months ago

I see this function in the package: getImagePredictionListObjectDetection(List imageAsBytesList). Is this intended to be used for a list of images?

abdelaziz-mahdy commented 11 months ago

No, this is intended for camera images.

abdelaziz-mahdy commented 11 months ago

Keep in mind that I don't want to provide a function that takes a list of images, since loading 1000 images at once will most probably cause memory problems.
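A minimal sketch of the memory-friendly pattern this implies on the caller's side: pass paths (or frame timestamps) around and read each image just before inference, so a List<Uint8List> with hundreds of decoded images never exists. predictOneAtATime and imagePaths are placeholder names.

    import 'dart:io';

    import 'package:pytorch_lite/pytorch_lite.dart';

    Future<void> predictOneAtATime(
        ModelObjectDetection model, List<String> imagePaths) async {
      for (final path in imagePaths) {
        final bytes = await File(path).readAsBytes();
        final detections = await model.getImagePrediction(bytes);
        // `bytes` goes out of scope each iteration, so only one decoded
        // image is held in memory at a time.
        print('$path -> ${detections.length} detections');
      }
    }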