ezmsg-org / ezmsg

Pure-Python DAG-based high-performance SHM-backed pub-sub and multi-processing pattern
https://ezmsg.readthedocs.io/en/latest/
MIT License

Add generator pattern example #55

Closed: pperanich closed this 10 months ago

pperanich commented 10 months ago

This example demonstrates a pattern in ezmsg where computation is encapsulated in Python generators, providing a structure that is both reusable in offline processing and easy to integrate with ezmsg. Added in response to https://github.com/iscoe/ezmsg/issues/54
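For reference, the pattern boils down to something like this (a minimal sketch: `consumer` and `compose` here are simplified stand-ins for the helpers in the example, and `scale` is a made-up processing step):

```python
from functools import reduce, wraps
from typing import Callable, Generator

def consumer(func: Callable[..., Generator]) -> Callable[..., Generator]:
    """Decorator that primes a generator so it is ready for .send()."""
    @wraps(func)
    def wrapper(*args, **kwargs) -> Generator:
        gen = func(*args, **kwargs)
        next(gen)  # advance to the first yield
        return gen
    return wrapper

def compose(*gens: Generator) -> Callable:
    """Chain primed generators into a single offline-callable pipeline."""
    def pipeline(x):
        return reduce(lambda acc, gen: gen.send(acc), gens, x)
    return pipeline

@consumer
def scale(k: float) -> Generator[float, float, None]:
    out = 0.0
    while True:
        x = yield out  # receive an input, emit the previously computed output
        out = k * x

pipeline = compose(scale(2.0), scale(3.0))
print(pipeline(1.0))  # -> 6.0
```

The same generators can be wrapped as ezmsg Units for online use, or called through `compose` in a notebook for offline analysis.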

cboulay commented 10 months ago

I'm very happy to see this example because it serves a specific design need quite well: a Unit's core logic can come from a library of documented and well-tested functions, independent of the composability -- and inherent unverifiability -- that ezmsg provides. This is great! This might provide a path to using ezmsg for prototyping, then using the same core algorithms in something more regulated and locked down.

Can we get `consumer`, `compose`, and `Gen` in a module, not just an example?

cboulay commented 10 months ago

About the "offline analysis" motivation, while I get it and I agree in principle, do you think this approach will work in a real-world analysis? A 1-channel 100-sample dataset can be analyzed in a single chunk so the output = pipeline(input) is certainly simpler. But what if we're dealing with data that can't be processed in a single chunk, like an hours-long 256-channel 30 khz data file?

Could you please replace one of the trivial `pow` or `add` examples with something (only slightly) more complicated that requires state, like a filter or a counter? Is the state maintained between calls to `pipeline(chunk)`? If it is, then I guess offline analysis of a very large file can just be done with a loop that iterates over the chunks in the file.
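To make that concrete, a stateful step would be something like this made-up running sum; the question is whether `total` survives between `pipeline(chunk)` calls:

```python
from typing import Generator

def running_sum() -> Generator[float, float, None]:
    total = 0.0  # state that must survive between chunks
    out = 0.0
    while True:
        x = yield out  # receive one chunk, emit the previous result
        total += x
        out = total

gen = running_sum()
next(gen)  # prime the generator
for chunk in [1.0, 2.0, 3.0]:
    print(gen.send(chunk))  # 1.0, 3.0, 6.0 -- locals persist between sends
```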

That being said, I think the parallelization / asynchronicity that the ezmsg graph provides would be very beneficial in an offline analysis, if it were possible to tune the `ChunkwiseFileReader(ez.Unit)` to push chunks at the optimal rate for the pipeline. But that's a different topic.

pperanich commented 10 months ago

Re "offline analysis":

I've used this pattern extensively in a library for processing digital holography data. In terms of data throughput, a 5-minute collect is on the order of 20 GB; the data in this case is typically images of 64x64 pixels streaming at ~10,000 fps. When processing the data offline, typically in a notebook, it is read in and streamed through the pipeline in batches with something like the following:

```python
from tqdm import tqdm  # progress bar over the batch iterator

batches = list(tqdm(
    map(
        pipeline,
        stream_holograms(dataset),
    ),
    total=num_batches,
))
```

where `stream_holograms` returns an iterator over the batches of data, reading the underlying data file in chunks.
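Schematically, the reader could look like this (a hypothetical sketch: the file format, names, and batch shape are placeholders, and the real reader yields `BatchedHolograms` rather than raw arrays):

```python
from typing import Iterator

import numpy as np

def stream_holograms(path: str, batch_sz: int = 4096) -> Iterator[np.ndarray]:
    """Hypothetical chunked reader: memory-map the recording and yield
    (batch_sz, 64, 64) blocks of frames instead of loading ~20 GB at once."""
    frames = np.load(path, mmap_mode="r")  # .npy memory-map; no full read
    for start in range(0, frames.shape[0], batch_sz):
        yield np.asarray(frames[start : start + batch_sz])
```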

I can include an example of an FIR filter below to explain some other benefits:

```python
from dataclasses import replace
from typing import Generator

import numpy as np
import numpy.typing as npt
import scipy.signal

# `consumer`, `cupy_free_on_close`, `ComputeMode`, `get_module`,
# `get_fftconvolve`, and `BatchedHolograms` are helpers from the holography
# library this excerpt comes from.


@consumer
@cupy_free_on_close
def firfilter(
    b: npt.NDArray, *, compute_mode: ComputeMode = ComputeMode.NUMPY,
) -> Generator[BatchedHolograms, BatchedHolograms, None]:
    # Pick numpy or cupy (and the matching fftconvolve) at construction time.
    xp = get_module(compute_mode)
    fftconvolve = get_fftconvolve(compute_mode)

    axis_arr_in: BatchedHolograms = BatchedHolograms(np.array([]), [""])
    axis_arr_out: BatchedHolograms = BatchedHolograms(np.array([]), [""])

    # Filter state: index tuples plus the carry-over (zi) between batches.
    ind_zi = None
    ind_out = None
    ind_zf = None
    weights = None
    zi = None

    while True:
        axis_arr_in = yield axis_arr_out
        arr_in = axis_arr_in.data

        # No taps: pass the message through unchanged.
        if len(b) <= 0:
            axis_arr_out = axis_arr_in
            continue

        # Lazily build the weights and initial conditions on the first batch,
        # once the input shape is known.
        if (
            weights is None
            or zi is None
            or ind_zi is None
            or ind_out is None
            or ind_zf is None
        ):
            weights, zi, ind_zi, ind_out, ind_zf = construct_weights(arr_in.shape, b)
            weights = xp.array(weights)
            zi = xp.array(zi)

        # Block convolution along axis 0: add the carried-over tail from the
        # previous batch, emit batch_sz samples, keep the new tail as state.
        arr_out_full = fftconvolve(arr_in, weights, axes=0)
        arr_out_full[ind_zi] += zi
        arr_out = arr_out_full[ind_out]
        zi = arr_out_full[ind_zf]

        axis_arr_out = replace(axis_arr_in, data=arr_out)


def construct_weights(data_shape, b):
    batch_sz = data_shape[0]
    channels = data_shape[1:]
    # Broadcast the 1-D taps across all non-batch dimensions.
    weights = np.expand_dims(np.array(b), tuple(i + 1 for i, _ in enumerate(channels)))

    # Initial conditions for the equivalent FIR lfilter, tiled per channel.
    zi = scipy.signal.lfilter_zi(b, [1.0])
    zi_0 = zi.shape[0]
    zi_shape = list(channels)
    zi = np.repeat(zi, np.prod(zi_shape))
    zi = zi.reshape([zi_0] + zi_shape)

    # Index tuples into the full convolution output along axis 0:
    ind = zi.ndim * [slice(None)]
    ind[0] = slice(zi_0)            # leading samples: add carried-over state
    ind_zi = tuple(ind)
    ind[0] = slice(batch_sz)        # first batch_sz samples: the output
    ind_out = tuple(ind)
    ind[0] = slice(batch_sz, None)  # tail: state for the next batch
    ind_zf = tuple(ind)
    return weights, zi, ind_zi, ind_out, ind_zf
```
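For context, standalone use of the above looks roughly like this (a sketch: the `firwin` taps are just an example, and in practice the batches are `BatchedHolograms` messages):

```python
import scipy.signal

b = scipy.signal.firwin(65, 0.1)  # example taps: 65-point low-pass FIR
filt = firfilter(b)               # @consumer has already primed the generator

for batch in stream_holograms(dataset):
    out = filt.send(batch)        # zi state carries over between batches
```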

A couple nice things about the above:

cboulay commented 10 months ago

Thank you for the added details and examples!

`gen_to_unit` looks very useful.

Can this pattern be extended to multi-input or multi-output functions, i.e., where steps 1 and 2 both feed into step 3? I don't expect so, and it seems out of scope. I don't have a specific use case yet, and maybe when I sit down and work one out, there will be a way to write the data iterator and pipeline so everything travels together.

pperanich commented 10 months ago

I've thought about multi-input and multi-output use before, but never really landed on a particular pattern. I think there are two potential methods you could use here:

1. Passing in the multiple inputs as an iterable (or similar for outputs):

```python
from typing import Generator, Tuple

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray

def multi_input_output() -> Generator[Tuple, Tuple, None]:
    axis_arr_out_1 = AxisArray(np.array([]), dims=[""])
    axis_arr_out_2 = AxisArray(np.array([]), dims=[""])
    while True:
        (axis_arr_in_1, axis_arr_in_2) = yield (axis_arr_out_1, axis_arr_out_2)
        # Trivial body: swap the two inputs.
        axis_arr_out_1 = axis_arr_in_2
        axis_arr_out_2 = axis_arr_in_1

gen = multi_input_output()
next(gen)

out_1, out_2 = gen.send(
    (AxisArray(np.arange(10), dims=[""]), AxisArray(np.arange(20), dims=[""]))
)
```

2. Differentiating based on message type within the generator:

```python
from typing import Any, Generator

def diff_types() -> Generator[Any, Any, None]:
    message_out = None
    while True:
        message_in = yield message_out
        if type(message_in) is AxisArray:
            # Do some computation specific to AxisArray
            message_out = ...
        elif type(message_in) is Metadata:
            # Store new metadata, etc...
            message_out = ...
```

Given these two approaches, I lean towards the second, because you could still make use of the `gen_to_unit` or sub-classing approach; it generalizes well, and all the logic still resides in the generator.
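For example, a single primed `diff_types` generator can be driven with either message type (a usage sketch; `Metadata` stands in for whatever metadata class a pipeline defines):

```python
import numpy as np

gen = diff_types()
next(gen)  # prime the generator

# An AxisArray message takes the first branch:
out = gen.send(AxisArray(np.arange(10), dims=[""]))

# A Metadata message would take the second branch through the very same
# generator, so one wrapper (e.g. via gen_to_unit) can serve both streams.
```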

cboulay commented 10 months ago

Hi @pperanich, I like this very much and I've started using it in my own code (see #65). Do you think it would be possible to move `consumer` and `compose` out of the examples and into ezmsg core?