ezmsg-org / ezmsg-sigproc

Timeseries signal processing implementations for ezmsg
MIT License

Efficient and safe mutation of AxisArray #7

Open cboulay opened 4 months ago

cboulay commented 4 months ago

I occasionally run into issues when manipulating AxisArray objects when I fail to properly break the link between the incoming message and the newly generated output. A message then gets modified in unexpected ways because I end up mutating a dictionary or other mutable object that is shared between the two. This happens only when using the (generator) transformation methods without the ezmsg runner, because the runner deepcopies every message unless told not to.
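A minimal sketch of the aliasing problem, using hypothetical Axis and Msg dataclasses as stand-ins for AxisArray and AxisArray.Axis (the real ezmsg classes have more fields, but the copy semantics of dataclasses are the same). dataclasses.replace only makes a shallow copy, so unless the output gets its own axes dict, writing into it also writes into the input message:

```python
from dataclasses import dataclass, field, replace


# Hypothetical stand-ins for AxisArray and AxisArray.Axis (not the ezmsg classes).
@dataclass
class Axis:
    gain: float = 1.0
    offset: float = 0.0


@dataclass
class Msg:
    data: list
    axes: dict = field(default_factory=dict)


in_msg = Msg(data=[0.0, 1.0, 2.0], axes={"time": Axis(gain=0.001, offset=0.0)})

# Pitfall: replace() is a shallow copy, so out_bad.axes IS in_msg.axes.
out_bad = replace(in_msg)
out_bad.axes["time"] = Axis(gain=0.001, offset=5.0)
print(in_msg.axes["time"].offset)  # 5.0 (the input message was modified too)

# Safer: give the output its own dict and a new Axis object for the changed axis.
in_msg = Msg(data=[0.0, 1.0, 2.0], axes={"time": Axis(gain=0.001, offset=0.0)})
out_ok = replace(
    in_msg,
    axes={**in_msg.axes, "time": replace(in_msg.axes["time"], offset=5.0)},
)
print(in_msg.axes["time"].offset)  # 0.0 (input untouched)
print(out_ok.axes["time"].offset)  # 5.0
```

The second pattern still shares the data payload and any untouched axis objects with the input, which is fine as long as nothing downstream mutates them in place.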

So I find myself re-inventing the solution to this problem over and over, and I've gotten it wrong a couple of times. For example, downsample has a bug that I've fixed in cboulay/working but haven't made a PR for yet. Furthermore, it turns out that the solution I reach for most often, because it needs the fewest lines of code, is slower than the more straightforward one.

I wrote a notebook for myself that goes through the issue in detail, so hopefully I'll never make the same mistake again, especially if I refer back to it. I'll add it to the repo somewhere eventually. For now, it is attached here: optim_axis_array_mods.ipynb.zip

Something surprising came out of this notebook: dataclasses.replace is more than twice as slow as constructing a new object directly. I wasn't expecting that.

def full_modify_1(in_msg):
    # Shallow-copy the message by calling its class directly.
    msg_cls = type(in_msg)
    out_msg = msg_cls(**in_msg.__dict__)
    # Give the output its own axes dict, swapping in a new "time" axis.
    ax_cls = type(in_msg.axes["time"])
    out_msg.axes = {
        k: (v if k != "time" else ax_cls(gain=in_msg.axes["time"].gain, offset=-88.88))
        for k, v in in_msg.axes.items()
    }
    return out_msg

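from dataclasses import replace

def full_modify_2(in_msg):
    # Copy the message and its axes dict with dataclasses.replace, swapping in a new "time" axis.
    return replace(
        in_msg,
        axes={
            **in_msg.axes,
            "time": replace(in_msg.axes["time"], offset=-99.99),
        },
    )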

%timeit [full_modify_1(_) for _ in in_msgs]
%timeit [full_modify_2(_) for _ in in_msgs]
1.91 ms ± 6.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.51 ms ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Or 0.9 vs 2.2 µs per message (2048 messages per loop). This doesn't sound like much, but if you have 20-50 nodes that manipulate the axes, the difference compounds.
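The per-message figures are the loop time divided by the 2048 messages: 1.91 ms / 2048 ≈ 0.93 µs and 4.51 ms / 2048 ≈ 2.20 µs. The sketch below is one way to reproduce this kind of comparison outside the notebook. It uses hypothetical stand-in dataclasses rather than ezmsg's AxisArray, and timeit from the standard library instead of the %timeit magic, so the absolute numbers (and possibly the ratio) will differ from the ones above:

```python
import timeit
from dataclasses import dataclass, field, replace


# Hypothetical stand-ins for AxisArray / AxisArray.Axis (not the ezmsg classes).
@dataclass
class Axis:
    gain: float = 1.0
    offset: float = 0.0


@dataclass
class Msg:
    data: list
    dims: list
    axes: dict = field(default_factory=dict)


def via_constructor(in_msg):
    # Shallow-copy by calling the class directly, then swap in a new axes dict.
    out_msg = type(in_msg)(**in_msg.__dict__)
    out_msg.axes = {
        k: (v if k != "time" else type(v)(gain=v.gain, offset=-99.99))
        for k, v in in_msg.axes.items()
    }
    return out_msg


def via_replace(in_msg):
    # Same result using dataclasses.replace at both levels.
    return replace(
        in_msg,
        axes={**in_msg.axes, "time": replace(in_msg.axes["time"], offset=-99.99)},
    )


N_MSGS = 2048
in_msgs = [
    Msg(data=[0.0] * 8, dims=["time", "ch"], axes={"time": Axis(gain=1 / 500)})
    for _ in range(N_MSGS)
]

for fn in (via_constructor, via_replace):
    n_loops = 20
    total_s = timeit.timeit(lambda: [fn(m) for m in in_msgs], number=n_loops)
    per_msg_us = total_s / (n_loops * N_MSGS) * 1e6
    print(f"{fn.__name__}: {per_msg_us:.2f} µs per message")
```

Both functions build equivalent outputs, so any timing gap comes from how the copy is made rather than from what gets copied.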

I'll make an optimization pass when I'm back from the conference, then I'll set zero_copy=True on all the nodes that I can.

It'll be nice to get rid of the deepcopy, because that costs about 40 µs per message in this setup!
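The deepcopy is expensive because it duplicates everything, including the data payload, whereas a targeted copy only rebuilds the containers that are actually being changed. A rough sketch of the difference, again with hypothetical stand-in dataclasses and a plain list where a real AxisArray would hold a numpy array:

```python
import copy
import timeit
from dataclasses import dataclass, field, replace


# Hypothetical stand-ins for AxisArray / AxisArray.Axis (not the ezmsg classes).
@dataclass
class Axis:
    gain: float = 1.0
    offset: float = 0.0


@dataclass
class Msg:
    data: list
    axes: dict = field(default_factory=dict)


def targeted_copy(in_msg):
    # New message and new axes dict; the data payload is shared, not copied.
    new_time = replace(in_msg.axes["time"], offset=1.0)
    return replace(in_msg, axes={**in_msg.axes, "time": new_time})


in_msg = Msg(data=[float(i) for i in range(1024)], axes={"time": Axis(gain=1 / 500)})

deep = copy.deepcopy(in_msg)     # nothing is shared with in_msg
targeted = targeted_copy(in_msg)

print(deep.data is in_msg.data)      # False
print(targeted.data is in_msg.data)  # True

n = 200
t_deep = timeit.timeit(lambda: copy.deepcopy(in_msg), number=n) / n * 1e6
t_targeted = timeit.timeit(lambda: targeted_copy(in_msg), number=n) / n * 1e6
print(f"deepcopy: {t_deep:.1f} µs, targeted copy: {t_targeted:.1f} µs")
```

The flip side of skipping the deepcopy (e.g. with zero_copy=True) is that the shared payload must be treated as read-only: a node that wants to modify data in place has to copy it first.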

cboulay commented 3 months ago

The "safety" of our mutations is addressed in #8. Efficiency isn't too bad, so we can push that off for now.

cboulay commented 1 month ago

To further address safety, I used frozendict to create the .axes fields of AxisArray objects in the unit tests in #20. It seems to work well.
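frozendict is a third-party package; to keep the sketch below dependency-free it swaps in types.MappingProxyType from the standard library as a read-only mapping with similar behaviour for this purpose. The idea is the same: if axes is immutable, an accidental in-place edit raises an error instead of silently leaking into the input message. The Axis and Msg classes are hypothetical stand-ins for the ezmsg types:

```python
from dataclasses import dataclass, field, replace
from types import MappingProxyType


# Hypothetical stand-ins for AxisArray / AxisArray.Axis (not the ezmsg classes).
@dataclass
class Axis:
    gain: float = 1.0
    offset: float = 0.0


@dataclass
class Msg:
    data: list
    axes: MappingProxyType = field(default_factory=lambda: MappingProxyType({}))


# Read-only axes mapping (a stdlib stand-in for frozendict).
in_msg = Msg(data=[0.0, 1.0], axes=MappingProxyType({"time": Axis(gain=0.01)}))

# An accidental in-place edit now fails loudly instead of leaking into the input.
try:
    in_msg.axes["time"] = Axis(gain=0.01, offset=3.0)
except TypeError as exc:
    print("blocked:", exc)

# The intended pattern still works: build a new mapping for the output.
out_msg = replace(
    in_msg,
    axes=MappingProxyType({**in_msg.axes, "time": replace(in_msg.axes["time"], offset=3.0)}),
)
print(in_msg.axes["time"].offset, out_msg.axes["time"].offset)  # 0.0 3.0
```

One caveat with this stand-in: a mappingproxy cannot be pickled or deep-copied, so frozendict is likely the better fit for messages that cross process boundaries or still go through the runner's deepcopy. The Axis objects themselves also remain mutable here; declaring them with @dataclass(frozen=True) would close that gap too.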