ezmsg-org / ezmsg

Pure-Python DAG-based high-performance SHM-backed pub-sub and multi-processing pattern
https://ezmsg.readthedocs.io/en/latest/
MIT License

How to track channel labels and locations? #43

Closed cboulay closed 1 year ago

cboulay commented 1 year ago

If I have a source Unit that has information about channel labels and locations, what is the recommended way to relay that to downstream Units? It's too much information to put into, e.g., a SpaceAxis and transmit with every message.

So do I add a META_OUTPUT port to my Unit and transmit it once when I get that information and again whenever it changes? Is there any way to guarantee that this port is transmitted first so that downstream Units will receive the updated metadata before receiving streaming data with the new format?

griffinmilsap commented 1 year ago

In practice, I subclass AxisArray and add custom fields to it that track the extra information. I tend to take it on the nose serializing that info with every message with the understanding that ezmsg is ... very ... fast. It seems inefficient, but it is the vastly simpler solution. Given ezmsg's message caching under the hood, the impact is actually quite minimal.

If you really want to do this with a metadata stream, I recommend setting up receivers for your META_OUTPUT that wait until that metadata is received before processing any EEG messages. This is actually a really common pattern (that is unfortunately a little verbose)... Here's a runnable example script:

import typing
import asyncio
from dataclasses import dataclass, replace

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt

from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.debuglog import DebugLog

@dataclass
class MetadataMessage:
    ch_names: typing.List[str]
    ch_locs: npt.NDArray

class EEGStreamer(ez.Unit):

    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
    OUTPUT_METADATA = ez.OutputStream(MetadataMessage)

    @ez.publisher(OUTPUT_SIGNAL)
    @ez.publisher(OUTPUT_METADATA)
    async def pub_signal(self) -> typing.AsyncGenerator:

        metadata = MetadataMessage( 
            ch_names = ['a','b','c'], 
            ch_locs = np.ones(3)
        )

        # Send the metadata once, before the first signal message is published
        yield self.OUTPUT_METADATA, metadata

        while True:

            yield self.OUTPUT_SIGNAL, AxisArray(
                data = np.ones((100, 3)), 
                dims = ['time', 'ch']
            )
            await asyncio.sleep(1.0)

class EEGModifierState(ez.State):
    incoming_signal: asyncio.Queue[AxisArray]
    incoming_metadata: asyncio.Queue[MetadataMessage]

class EEGModifier(ez.Unit):

    STATE: EEGModifierState

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    INPUT_METADATA = ez.InputStream(MetadataMessage)

    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self):
        self.STATE.incoming_metadata = asyncio.Queue()
        self.STATE.incoming_signal = asyncio.Queue()

    @ez.subscriber(INPUT_SIGNAL)
    async def on_signal(self, msg: AxisArray) -> None:
        self.STATE.incoming_signal.put_nowait(msg)

    @ez.subscriber(INPUT_METADATA)
    async def on_metadata(self, msg: MetadataMessage) -> None:
        self.STATE.incoming_metadata.put_nowait(msg)

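    @ez.publisher(OUTPUT_SIGNAL)
    async def modify_eeg(self) -> typing.AsyncGenerator:
        # Block here until metadata arrives; signal messages received in the
        # meantime accumulate in the incoming_signal queue
        metadata = await self.STATE.incoming_metadata.get()

        while True:
            eeg = await self.STATE.incoming_signal.get()
            # replace() copies the message and swaps in the modified data
            yield self.OUTPUT_SIGNAL, replace(eeg, data = eeg.data + metadata.ch_locs)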

streamer = EEGStreamer()
modifier = EEGModifier()
log = DebugLog()

ez.run(
    STREAMER = streamer,
    MODIFIER = modifier,
    LOG = log,
    connections = (
        (streamer.OUTPUT_METADATA, modifier.INPUT_METADATA),
        (streamer.OUTPUT_SIGNAL, modifier.INPUT_SIGNAL),
        (modifier.OUTPUT_SIGNAL, log.INPUT)
    )
)
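
The gating idea in `modify_eeg` can also be tried without ezmsg. The following is a minimal sketch using only `asyncio`; the names `Metadata`, `ch_offsets`, `modify`, and `demo` are hypothetical stand-ins and not part of ezmsg. It shows that signal chunks arriving before the metadata simply wait in their queue. The inner loop that picks up newer metadata is an addition to the example above, and it does not guarantee ordering between the two streams.

```python
import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class Metadata:
    # Hypothetical stand-in for MetadataMessage; plain lists instead of arrays.
    ch_names: List[str]
    ch_offsets: List[float]


async def modify(signal_q: asyncio.Queue, meta_q: asyncio.Queue, out: list, n: int) -> None:
    # Block until the first metadata message arrives; signal chunks that
    # arrive earlier simply wait in signal_q.
    meta = await meta_q.get()
    for _ in range(n):
        chunk = await signal_q.get()
        # Pick up a newer metadata message if one is already queued.
        while not meta_q.empty():
            meta = meta_q.get_nowait()
        out.append([x + off for x, off in zip(chunk, meta.ch_offsets)])


async def demo() -> list:
    signal_q: asyncio.Queue = asyncio.Queue()
    meta_q: asyncio.Queue = asyncio.Queue()
    out: list = []
    task = asyncio.create_task(modify(signal_q, meta_q, out, n=2))

    # Signal arrives before metadata: nothing is processed yet.
    signal_q.put_nowait([1.0, 1.0, 1.0])
    await asyncio.sleep(0)
    assert out == []

    # Once metadata shows up, the queued chunks are processed in order.
    meta_q.put_nowait(Metadata(["a", "b", "c"], [0.0, 1.0, 2.0]))
    signal_q.put_nowait([2.0, 2.0, 2.0])
    await task
    return out


result = asyncio.run(demo())
print(result)
```

In the ezmsg version, the two subscriber coroutines play the role of the `put_nowait` calls here, and the publisher coroutine plays the role of `modify`.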

griffinmilsap commented 1 year ago

As a matter of fact, you can do it like this:

@dataclass
class EEGMessage(AxisArray):
    ch_names: typing.List[str]
    ch_locs: npt.NDArray

Under the hood, message passing is implemented with only one copy (with zero-copy reads) for any data type that uses the array API. It might just be faster than you think ;)
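
One general dataclass detail is worth keeping in mind when subclassing this way: if the base class already has fields with defaults, any fields added in the subclass need defaults too, or Python raises a `TypeError` at class definition. The sketch below uses a hypothetical `BaseArray` as a stand-in for AxisArray (its field layout is an assumption, so check the AxisArray definition in your installed version) and shows that `dataclasses.replace` carries the extra fields along when only the data changes.

```python
from dataclasses import dataclass, field, replace
from typing import Dict, List


@dataclass
class BaseArray:
    # Hypothetical stand-in for AxisArray: two required fields, one defaulted.
    data: List[float]
    dims: List[str]
    axes: Dict[str, str] = field(default_factory=dict)


@dataclass
class LabeledArray(BaseArray):
    # Extra fields need defaults because the base already has a defaulted field.
    ch_names: List[str] = field(default_factory=list)
    ch_locs: List[float] = field(default_factory=list)


msg = LabeledArray(
    data=[1.0, 2.0],
    dims=["ch"],
    ch_names=["a", "b"],
    ch_locs=[0.0, 1.0],
)

# replace() copies every field except the ones overridden,
# so the channel labels and locations travel with the new data.
scaled = replace(msg, data=[x * 2 for x in msg.data])
print(scaled)
```

A downstream Unit that only modifies `data` via `replace` therefore never has to know about the extra fields.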

griffinmilsap commented 1 year ago

closing for now, but this issue should probably be referenced in a future "Patterns" or FAQ page (#54)