Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Add real-time inference support in bundle #8134

Open Nic-Ma opened 1 week ago

Nic-Ma commented 1 week ago

Hi @ericspod, @KumoLiu,

Recently I have received more and more feature requests to run bundles for real-time inference in MONAI Label, MONAI Deploy, NVIDIA NIMs, etc. There are two main blockers to supporting this:

  1. Our current inference examples are for batch inference (for example, https://github.com/Project-MONAI/model-zoo/blob/dev/models/spleen_ct_segmentation/configs/inference.json). We lazily instantiate all the components in the config and predefine the whole datalist in the config. For real-time inference, however, we need to instantiate all the Python components defined in the config up front, keep the model idle on the GPU, and then wait for input data requests. Our current design can't change the config content once it is instantiated, because we do a deep copy during parsing: https://github.com/Project-MONAI/MONAI/blob/dev/monai/bundle/config_parser.py#L347 I made a very hacky workaround to replace the input data; it works but is obviously not general for all bundles:
    ConfigWorkflow.parser.ref_resolver.items["dataset"].config["data"][0] = input_data
  2. [Optional] Third-party applications usually have their own input and output pipelines, so they need to remove or replace the LoadImage and SaveImage transforms in the bundle config. We only have MERGE_KEY and are missing a delete key: https://github.com/Project-MONAI/MONAI/blob/dev/monai/bundle/utils.py#L252
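
For the second blocker, one way a delete key could behave is sketched below in plain Python, as an illustration only. It assumes a hypothetical `~` prefix that removes the referenced item, alongside the `+` prefix that MERGE_KEY uses today, and it assumes `::` as the separator for nested IDs. The `apply_overrides` helper and the `~` key are not existing MONAI APIs.

```python
from copy import deepcopy
from typing import Any

SEP = "::"           # assumed separator for nested config IDs
MERGE_PREFIX = "+"   # mirrors the existing merge key
DELETE_PREFIX = "~"  # hypothetical delete key, not an existing MONAI feature


def _parent_and_key(config: Any, path: str) -> tuple[Any, str]:
    """Walk all but the last component of ``path``; return (container, last component)."""
    *parents, last = path.split(SEP)
    node = config
    for part in parents:
        node = node[int(part)] if isinstance(node, list) else node[part]
    return node, last


def apply_overrides(config: dict, overrides: dict) -> dict:
    """Return a copy of ``config`` with ``overrides`` applied in insertion order.

    "~id" deletes the item, "+id" merges a dict or extends a list, and a plain
    "id" replaces the value.
    """
    result = deepcopy(config)  # never mutate the caller's config
    for raw_id, value in overrides.items():
        if raw_id.startswith(DELETE_PREFIX):
            parent, key = _parent_and_key(result, raw_id[len(DELETE_PREFIX):])
            if isinstance(parent, list):
                del parent[int(key)]
            else:
                parent.pop(key, None)  # deleting a missing dict key is a no-op
            continue
        merge = raw_id.startswith(MERGE_PREFIX)
        path = raw_id[len(MERGE_PREFIX):] if merge else raw_id
        parent, key = _parent_and_key(result, path)
        index = int(key) if isinstance(parent, list) else key
        if merge:
            current = parent[index]
            if isinstance(current, dict) and isinstance(value, dict):
                current.update(value)
            elif isinstance(current, list) and isinstance(value, list):
                current.extend(value)
            else:
                raise TypeError(f"cannot merge into {path!r}")
        else:
            parent[index] = value
    return result


config = {
    "preprocessing": {"transforms": [{"_target_": "LoadImaged"}, {"_target_": "EnsureChannelFirstd"}]},
    "postprocessing": {"transforms": [{"_target_": "Activationsd"}, {"_target_": "SaveImaged"}]},
}
stripped = apply_overrides(config, {
    "~preprocessing::transforms::0": None,   # drop LoadImaged
    "~postprocessing::transforms::1": None,  # drop SaveImaged
})
print([t["_target_"] for t in stripped["preprocessing"]["transforms"]])  # ['EnsureChannelFirstd']
```

Because overrides are applied in order, deleting several indices from the same list should go from the highest index down, otherwise later indices shift.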

Could you please help investigate this problem so we can work out an ideal solution together? It could be an important feature for MONAI 1.5.

Thanks in advance.

ericspod commented 4 days ago

Hi @Nic-Ma, I've had a chance to put some ideas together and see whether our workflow classes can work with them. My thought was to define a dataset that represents the input source of streaming data. It would be fed by whatever the actual source is, e.g. something capturing frames from a video source or reading from a sensor. This dataset would be iterable, so it can inherently represent an infinite series of data items, and when nothing is arriving it simply waits without yielding anything. The second part would be a transform in the postprocessing stage of the workflow that acts as the sink, or consumer, of what the network predicts. It would hand off tensors to whatever the actual sink for the stream is and perform any conversions needed. I have put these ideas together in the following demo, which should run as a standalone script:

from queue import Empty, Queue
import time
from threading import Thread, RLock
from monai.engines import SupervisedEvaluator
from monai.transforms import Transform
from monai.utils.enums import CommonKeys
import torch

class IterableBufferDataset(torch.utils.data.IterableDataset):
    """Defines an iterable dataset using a Queue object to permit asynchronous additions of new items, e.g. frames."""

    STOP = object()  # stop sentinel

    def __init__(self, buffer_size: int = 0, timeout: float = 0.01):
        super().__init__()
        self.buffer_size = buffer_size
        self.timeout = timeout
        self.buffer: Queue = Queue(self.buffer_size)
        self._is_running = False
        self._lock = RLock()

    @property
    def is_running(self):
        with self._lock:
            return self._is_running

    def add_item(self, item):
        """
        The idea is that the source of the streaming data adds items here and the engine consumes them immediately.
        The engine's `run` method would be running in the main thread or some other thread separate from the source,
        e.g. something reading from a port or from a device that puts individual video frames here.
        """
        self.buffer.put(item, timeout=self.timeout)

    def stop(self):
        with self._lock:
            self._is_running = False

    def __iter__(self):
        """
        Continually attempt to get an item from the queue until STOP is received or stop() is called.
        """
        with self._lock:
            self._is_running = True

        try:
            while self.is_running:  # checking exit condition prevents deadlock
                try:
                    item = self.buffer.get(timeout=self.timeout)

                    if item is IterableBufferDataset.STOP:  # stop looping when sentinel received
                        break

                    yield item
                except Empty:
                    pass  # queue was empty this time, try again
        finally:
            self.stop()

ds = IterableBufferDataset()

def stream_source():
    """Adds items into the dataset as if this were an asynchronous source of streaming data."""
    for i in range(1, 6):
        ds.add_item(torch.full((i,), i))
        time.sleep(0.5)

    ds.add_item(ds.STOP)

t = Thread(target=stream_source)
t.start()

class StreamSink(Transform):
    """Represents a sink of streaming data: this postprocessing transform consumes network outputs as they arrive."""

    def __call__(self, data):
        print("Stream sink:", data[CommonKeys.PRED])
        return data  # return the input so the engine output isn't replaced with None

class Trace(torch.nn.Module):
    """Simple test network which just squares input."""

    def forward(self, x):
        print("Net input:", x)
        return x**2

# The engine should be set to not decollate and to use the dataset directly as the loader; this minimizes what is done
# to the data between its arrival and the sink.
evaluator = SupervisedEvaluator(
    device="cpu", val_data_loader=ds, network=Trace(), epoch_length=1, decollate=False, postprocessing=StreamSink()
)

# monai/engines/workflow.py:140 makes a check against epoch_length that isn't valid here: epoch_length can be left as
# None to permit data loading of arbitrary length
evaluator.state.epoch_length = None

evaluator.run()
t.join()  # the source thread has already sent STOP; wait for it to exit cleanly
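
The demo's `StreamSink` only prints. A sink that actually hands predictions to a downstream consumer (an encoder, a display, a socket writer) could push them onto a second queue drained by another thread, so the engine never waits on the consumer. Below is a minimal stdlib-only sketch of that hand-off; the `QueueSink` and `consume` names are hypothetical, the `"pred"` key is assumed to match `CommonKeys.PRED`, and in a MONAI workflow the class would subclass `Transform` as `StreamSink` does.

```python
from queue import Queue
from threading import Thread

PRED_KEY = "pred"  # assumed to equal monai.utils.enums.CommonKeys.PRED
STOP = object()    # sentinel telling the consumer thread to finish


class QueueSink:
    """Postprocessing callable that forwards each prediction to an output queue."""

    def __init__(self, out_queue: Queue, key: str = PRED_KEY):
        self.out_queue = out_queue
        self.key = key

    def __call__(self, data: dict) -> dict:
        self.out_queue.put(data[self.key])  # unbounded queue, so put() never blocks the engine
        return data  # pass the data through unchanged for any later handlers


def consume(out_queue: Queue, results: list) -> None:
    """Stand-in for the real downstream sink; collects items until STOP arrives."""
    while (item := out_queue.get()) is not STOP:
        results.append(item)


out_q: Queue = Queue()
received: list = []
consumer = Thread(target=consume, args=(out_q, received))
consumer.start()

sink = QueueSink(out_q)
for i in range(1, 4):  # simulate three engine iterations
    sink({"image": [i], PRED_KEY: [i * i]})

out_q.put(STOP)
consumer.join()
print(received)  # [[1], [4], [9]]
```

An unbounded output queue keeps inference latency independent of the consumer, but memory grows if the consumer falls behind; a bounded `Queue(maxsize)` trades that for back-pressure on the engine.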
ericspod commented 4 days ago

Also, the codebase changes in #8146 may be worthwhile additions anyway, even if they are not needed for streaming.

Nic-Ma commented 3 days ago

Hi @ericspod,

Thanks for sharing. Last week I also built a very similar prototype of this streaming approach locally: I tried making a subclass that inherits from both the MONAI IterableDataset and the Python Queue. We can involve more people in the discussion.
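
The Queue-subclass idea can be sketched roughly as follows. To keep the sketch runnable without torch or MONAI, `QueueDataset` (a hypothetical name) inherits only from `queue.Queue` and adds `__iter__`; the prototype described above would additionally inherit from MONAI's `IterableDataset`, which is omitted here.

```python
from queue import Empty, Queue
from threading import Event, Thread
import time


class QueueDataset(Queue):
    """A Queue that can also be iterated like a dataset: producers call put(), the engine iterates."""

    STOP = object()  # sentinel a producer puts to end iteration

    def __init__(self, maxsize: int = 0, timeout: float = 0.01):
        super().__init__(maxsize)
        self.timeout = timeout
        self._stopped = Event()  # lets another thread end iteration without sending STOP

    def stop(self) -> None:
        self._stopped.set()

    def __iter__(self):
        self._stopped.clear()
        while not self._stopped.is_set():
            try:
                item = self.get(timeout=self.timeout)  # inherited Queue.get
            except Empty:
                continue  # nothing arrived within the timeout, poll again
            if item is QueueDataset.STOP:
                break
            yield item


ds = QueueDataset()


def producer() -> None:
    """Stand-in for a frame grabber or sensor reader."""
    for i in range(3):
        ds.put(i)  # inherited Queue.put, no wrapper method needed
        time.sleep(0.05)
    ds.put(QueueDataset.STOP)


t = Thread(target=producer)
t.start()
items = list(ds)  # blocks until the producer sends STOP
t.join()
print(items)  # [0, 1, 2]
```

Compared with `IterableBufferDataset` in the earlier demo, producers get `put`, `qsize`, and `maxsize` for free, at the cost of exposing the whole Queue interface on the dataset object.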

Thanks.

ericspod commented 17 hours ago

This is a good start for some use cases, so we should continue this development for 1.5. We should consider which use cases people who want streaming inference, real-time or not, actually have. Some will come with bundles or other applications written with MONAI and will want to use MONAI transforms, engine classes, and other components for this work, which is what we're targeting here. Others will want much faster processing with more integrated infrastructure like TensorRT, or with lower-level streaming software like GStreamer. For these use cases, bundles with engine classes and everything else they use may not be suitable, so users will have to pull the components apart and fit them into that other workflow somehow. Some users will want to use Holoscan, DeepStream, or other technologies as well. So I'd say we're off to a good start, but we need to discuss what the other use cases are.

Nic-Ma commented 16 hours ago

@ericspod Makes sense to me. Adding @MMelQin to the discussion, as Ming is an expert on model inference and deployment.

Thanks.