janreising / astroCAST

Package to classify astrocytic calcium events
GNU General Public License v3.0

add caching to `Events.get_extended_events` #65

Open janreising opened 1 year ago

janreising commented 1 year ago

If we added the `@wrapper_local_cache` wrapper to this function, we would have to compute the hash of a large array, which could take a long time and could require loading the array into memory. Neither is desirable.

    # @wrapper_local_cache
    def get_extended_events(self, video=None, dtype=np.half, extend=-1,
                            return_array=False, in_place=False,
                            normalization_instructions=None, show_progress=True,
                            memmap_path=None, save_path=None, save_param={}):
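
One way to avoid hashing the array contents is to fingerprint the file that backs the array instead. The sketch below is only an illustration of that idea, not something astroCAST implements: `cheap_fingerprint` is a hypothetical helper, and it assumes the video is stored in a single file on disk.

```python
import hashlib
from pathlib import Path


def cheap_fingerprint(path) -> str:
    """Fingerprint a file from its metadata, without reading its contents.

    Hypothetical helper: combines the resolved path, the size in bytes and
    the modification time into one digest.
    """
    p = Path(path).resolve()
    st = p.stat()
    payload = f"{p}|{st.st_size}|{st.st_mtime_ns}".encode("utf-8")
    return hashlib.sha1(payload).hexdigest()
```

The trade-off is that an edit which preserves both size and modification time goes undetected, and arrays that exist only in memory have no file to fingerprint.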

janreising commented 1 year ago

The previous approach was optimized to use the built-in hashing of dask.array (`array.name`). While this is significantly faster, the hash is not always the same across different kernels. It is currently unclear whether this is caused by the dask package or by our implementation. This will have to be revisited later.

@pytest.mark.parametrize("lazy", [True, False])
def test_extension_cache(self, lazy, shape=(50, 100, 100)):

    with tempfile.TemporaryDirectory() as tmpdir:

        # Create dummy data
        event_dir = create_sim_data(Path(tmpdir), shape=shape)

        # load events - 1
        events_1 = Events(event_dir, lazy=lazy, cache_path=event_dir)
        video = Video(events_1.event_map)

        t0 = time.time()
        trace = events_1.get_extended_events(video=video, extend=(2, 2), in_place=True)
        d1 = time.time() - t0

        # load events - 2
        events_2 = Events(event_dir, lazy=lazy, cache_path=event_dir)
        video = Video(events_2.event_map)

        t0 = time.time()
        trace_2 = events_2.get_extended_events(video=video, extend=(2, 2))
        events_2.events = trace_2
        d2 = time.time() - t0

        assert d2 < d1, f"caching is taking too long: {d2} >= {d1}"
        assert hash(events_1) == hash(events_2)
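
The final assertion compares hashes inside a single interpreter process, so it cannot reveal differences between kernels. One general Python behaviour worth ruling out: the built-in `hash()` of `str` and `bytes` is salted per process unless `PYTHONHASHSEED` is set. Whether this plays any role here is not established in this thread. As a minimal sketch, a key derived with `hashlib` is independent of the process; `stable_key` and its parameters are hypothetical and not part of astroCAST.

```python
import hashlib


def stable_key(func_name: str, **params) -> str:
    """Build a cache key that is identical across interpreter sessions.

    Hypothetical helper: parameters are sorted by name and rendered with
    repr(), so the values need a deterministic repr (ints, strings, tuples).
    """
    parts = [func_name]
    for name in sorted(params):
        parts.append(f"{name}-{params[name]!r}")
    payload = "_".join(parts).encode("utf-8")
    # sha256 does not depend on PYTHONHASHSEED, unlike hash() on strings
    return hashlib.sha256(payload).hexdigest()[:16]


key = stable_key("get_extended_events", extend=(2, 2), video="array-name")
```

This only helps if every value passed in is itself stable; a dask array name that changes between kernels would still produce a different key.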

Recaching results:

WARNING  root:analysis.py:957 da.Array.name: array-53cdd4bfb52963a984b56e0ec6abef2d
WARNING  root:helper.py:123 hash_string: get_extended_events_3290801508159602123_extend-(2, 2)_video-7205464800592760953_
WARNING  root:analysis.py:957 da.Array.name: array-132e27ed6d9c20459b0980bda17c8024
WARNING  root:helper.py:123 hash_string: get_extended_events_3290801508159602123_extend-(2, 2)_video--8314633131423857843_
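
To see which part of the cache key changes between the two runs, the two `hash_string` values can be compared field by field. The sketch below assumes the layout visible in the log lines (function name, an integer, `extend-...`, `video-...`); the field name `object_hash` is a guess at what the first integer represents, and `split_key` and `differing_fields` are hypothetical helpers.

```python
import re

# Layout inferred from the logged hash_string values; this is an assumption.
KEY_PATTERN = re.compile(
    r"^(?P<func>.+?)_(?P<object_hash>-?\d+)"
    r"_extend-(?P<extend>\(.*?\))_video-(?P<video>-?\d+)_$"
)


def split_key(hash_string: str) -> dict:
    """Split a logged hash_string into its named components."""
    match = KEY_PATTERN.match(hash_string)
    if match is None:
        raise ValueError(f"unexpected key layout: {hash_string!r}")
    return match.groupdict()


def differing_fields(key_a: str, key_b: str) -> list:
    """Return the names of the components that differ between two keys."""
    a, b = split_key(key_a), split_key(key_b)
    return [name for name in a if a[name] != b[name]]


key_a = "get_extended_events_3290801508159602123_extend-(2, 2)_video-7205464800592760953_"
key_b = "get_extended_events_3290801508159602123_extend-(2, 2)_video--8314633131423857843_"
print(differing_fields(key_a, key_b))  # ['video']
```

For the two keys logged above, only the `video` component differs, which matches the differing `da.Array.name` values.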