scverse / spatialdata

An open and interoperable data framework for spatial omics data
https://spatialdata.scverse.org/
BSD 3-Clause "New" or "Revised" License

Optimize writing performance for `MultiscaleSpatialImage` #577

Open LucaMarconato opened 3 weeks ago

LucaMarconato commented 3 weeks ago

As observed by @ArneDefauw, unnecessary loading operations are performed when calling `write_multiscale()` on a list of lazy tensors derived from `to_multiscale()`.

Optimizing the order in which the data is computed and written to disk, so as to avoid loading the same chunks more than once, would probably lead to a drastic performance improvement, up to 10-fold.
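
To illustrate the idea, here is a minimal sketch using plain dask rather than the spatialdata writer. The names `expensive_step`, `level0` and `level1` are hypothetical, `da.coarsen` stands in for whatever downsampling `to_multiscale()` performs, and a counter replaces real I/O. It compares computing each scale on its own with computing both scales in one call, so that dask can share the upstream chunks.

```python
import dask
import dask.array as da
import numpy as np

calls = {"n": 0}  # number of real chunks processed by the expensive step


def expensive_step(block):
    # Stand-in for loading or processing one chunk (hypothetical).
    # Empty probe arrays, if dask ever passes any, are not counted.
    if block.size:
        calls["n"] += 1
    return block * 2.0


source = da.ones((8, 8), chunks=(4, 4))  # 4 chunks
level0 = da.map_blocks(
    expensive_step, source, dtype=float, meta=np.array((), dtype=float)
)
# Assumed downsampling: 2x2 mean pooling as a stand-in for the next scale.
level1 = da.coarsen(np.mean, level0, {0: 2, 1: 2})

# One compute per scale: every scale re-runs the expensive step on each chunk.
calls["n"] = 0
level0.compute(scheduler="synchronous")
level1.compute(scheduler="synchronous")
calls_separate = calls["n"]

# One compute for all scales: shared upstream tasks run only once.
calls["n"] = 0
full, half = dask.compute(level0, level1, scheduler="synchronous")
calls_together = calls["n"]

print(calls_separate, calls_together)
```

For writing to disk, `da.store` also accepts lists of sources and targets, which might allow the same sharing across scales. Whether that fits the current writer is an open question, not something this sketch establishes.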

ArneDefauw commented 3 weeks ago

I include a minimal example to reproduce the observed behaviour.

If arr.persist() is called, the code completes in ~10 s, but if arr.persist() is commented out, the code completes in ~50 s (i.e. 10 s for each of the 5 scales, because _some_function is called 5 times). As an alternative to .persist(), writing to an intermediate zarr store and then loading it back 'solves' the problem in a similar way.


```python
import os
import tempfile
import time

import dask.array as da
import numpy as np
import spatialdata
from spatialdata.datasets import blobs

sdata = blobs()

start = time.time()

with tempfile.TemporaryDirectory() as temp_dir:
    # Write the dataset first so that it is backed by a zarr store on disk.
    sdata.write(os.path.join(temp_dir, "sdata_blobs_dummy.zarr"))

    def _some_function(arr):
        # Simulate an expensive per-chunk operation.
        arr = arr * 2
        time.sleep(10)
        return arr

    arr = sdata["blobs_image"].data

    arr = da.map_blocks(_some_function, arr, dtype=float, meta=np.array((), dtype=float))

    # Comment out this line to reproduce the slow (~50 s) behaviour.
    arr = arr.persist()

    # or as alternative to persist, write to intermediate zarr store
    # dask_zarr_path = os.path.join(temp_dir, "dask_array.zarr")
    # arr.to_zarr(dask_zarr_path, overwrite=True)
    # arr = da.from_zarr(dask_zarr_path)

    # Build a multiscale image: the base scale plus 4 downsampled scales.
    se = spatialdata.models.Image2DModel.parse(
        arr,
        scale_factors=[2, 2, 2, 2],
    )

    sdata["blobs_image_processed"] = se

    sdata.write_element("blobs_image_processed")

print(time.time() - start)
```
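
The same pattern can be checked without spatialdata and without sleeping by counting calls instead of measuring time. This is only a sketch under assumptions: `expensive_step`, `build_pyramid` and `count_calls` are hypothetical names, `da.coarsen` stands in for the real downsampling, and one `.compute()` per level stands in for writing one scale to disk.

```python
import dask.array as da
import numpy as np

calls = {"n": 0}  # number of real chunks processed by the expensive step


def expensive_step(block):
    # Counting stand-in for _some_function; empty probe arrays are ignored.
    if block.size:
        calls["n"] += 1
    return block * 2.0


def build_pyramid(base, n_levels=5):
    # Each level is derived lazily from the previous one by 2x2 mean pooling.
    levels = [base]
    for _ in range(n_levels - 1):
        levels.append(da.coarsen(np.mean, levels[-1], {0: 2, 1: 2}))
    return levels


def count_calls(persist):
    source = da.ones((32, 32), chunks=(16, 16))  # 4 chunks
    arr = da.map_blocks(
        expensive_step, source, dtype=float, meta=np.array((), dtype=float)
    )
    calls["n"] = 0
    if persist:
        # Materialize the processed chunks once, before building the scales.
        arr = arr.persist(scheduler="synchronous")
    for level in build_pyramid(arr):
        # Assumed stand-in for writing one scale: each level computed alone.
        level.compute(scheduler="synchronous")
    return calls["n"]


lazy_calls = count_calls(persist=False)
persisted_calls = count_calls(persist=True)
print(lazy_calls, persisted_calls)
```

Without persisting, every level triggers the expensive step again for each chunk, which matches the 5x slowdown reported above. With persisting, each chunk is processed once, at the cost of holding the processed array in memory.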