LucaMarconato opened this issue 3 weeks ago
Below is a minimal example to reproduce the observed behaviour. If `arr.persist()` is called, the code completes in ~10 s, but if `arr.persist()` is commented out, the code completes in ~50 s (i.e. 10 s for each scale; `_some_function` is called 5 times). As an alternative to `.persist()`, writing to a zarr store and then loading the data back evidently "solves" the problem in a similar way.
```python
import os
import tempfile
import time

import dask.array as da
import numpy as np

import spatialdata
from spatialdata.datasets import blobs

sdata = blobs()

start = time.time()
with tempfile.TemporaryDirectory() as temp_dir:
    sdata.write(os.path.join(temp_dir, "sdata_blobs_dummy.zarr"))

    def _some_function(arr):
        arr = arr * 2
        time.sleep(10)
        return arr

    arr = sdata["blobs_image"].data
    arr = da.map_blocks(_some_function, arr, dtype=float, meta=np.array((), dtype=float))
    arr = arr.persist()
    # or as alternative to persist, write to intermediate zarr store
    # dask_zarr_path = os.path.join(temp_dir, "dask_array.zarr")
    # arr.to_zarr(dask_zarr_path, overwrite=True)
    # arr = da.from_zarr(dask_zarr_path)

    se = spatialdata.models.Image2DModel.parse(
        arr,
        scale_factors=[2, 2, 2, 2],
    )
    sdata["blobs_image_processed"] = se
    sdata.write_element("blobs_image_processed")

print(time.time() - start)
```
As observed by @ArneDefauw, unnecessary loading operations are performed when calling `write_multiscale()` on a list of lazy tensors derived from `to_multiscale()`. Optimizing the order in which the data is computed and written to disk, so as to avoid loading the same chunks two or more times, would probably lead to a drastic performance improvement, up to 10-fold.
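For reference, the suspected recomputation pattern can be shown with plain dask, independent of spatialdata. This is a minimal sketch under the assumption that the pyramid levels are lazy views sharing one expensive upstream step; `da.coarsen` merely stands in for whatever downscaling `to_multiscale()` actually performs, and the call counter is only illustrative:

```python
import dask
import dask.array as da
import numpy as np

calls = {"n": 0}

def _expensive(block):
    # stand-in for _some_function above; counts how often a block is processed
    calls["n"] += 1
    return block * 2

base = da.ones((64, 64), chunks=(64, 64))
processed = base.map_blocks(_expensive, dtype=float, meta=np.array((), dtype=float))

# a small "multiscale" pyramid of lazy views derived from `processed`
scales = [processed]
for _ in range(2):
    scales.append(da.coarsen(np.mean, scales[-1], {0: 2, 1: 2}))

# computing each scale in a separate pass re-runs the shared step per scale
calls["n"] = 0
for s in scales:
    s.compute()
print("separate passes:", calls["n"])  # -> 3

# a single compute over all scales deduplicates the shared subgraph
calls["n"] = 0
dask.compute(*scales)
print("single pass:    ", calls["n"])  # -> 1
```

If this is indeed what happens inside `write_multiscale()`, computing and storing all scales in a single pass (or persisting the base scale first, as in the workaround above) should yield roughly the per-scale speedup estimated above.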