NSLS-II / PyXRF

Fluorescence fitting GUI.
http://nsls-ii.github.io/PyXRF
BSD 3-Clause "New" or "Revised" License

Fix issues with recent versions of Dask and Numba #313

Closed dmgav closed 9 months ago

dmgav commented 9 months ago

This PR contains fixes for two issues: (1) with Dask and (2) with Numba/NumPy:

  1. Issue with Dask. Starting with Dask v2023.9.3, the workaround used to force Dask workers to close HDF5 files stopped working. An issue was opened in the distributed package: https://github.com/dask/distributed/issues/8452. A working solution (also a workaround) can be found in https://gist.github.com/dmgav/9fa69d1e507eff46e8098f082b0a1611 and consists of overriding the default serializers and deserializers for h5py files and datasets. The code is also included in this PR description below.
#############  File 'deserializers.py'  ####################

import distributed.protocol.h5py
from distributed.protocol.serialize import dask_serialize, dask_deserialize

# Track files opened during deserialization so they can be closed later
deserialized_files = set()

def serialize_h5py_file(f):
    # Serialize an h5py file by sending only its filename;
    # the file is reopened on the receiving worker.
    if f and (f.mode != "r"):
        raise ValueError("Can only serialize read-only h5py files")
    filename = f.filename if f else None
    return {"filename": filename}, []

def serialize_h5py_dataset(x):
    # Serialize a group/dataset as the filename plus the object's name in the file
    header, _ = serialize_h5py_file(x.file if x else None)
    header["name"] = x.name if x else None
    return header, []

def deserialize_h5py_file(header, frames):
    import h5py

    filename = header["filename"]
    if filename:
        # Reopen the file read-only and remember the handle so that
        # close_all_files() can close it later
        file = h5py.File(filename, mode="r")
        deserialized_files.add(file)
    else:
        file = None
    return file

def deserialize_h5py_dataset(header, frames):
    file = deserialize_h5py_file(header, frames)
    name = header["name"]
    dset = file[name] if (file and name) else None
    return dset

def set_custom_serializers():
    # Override the default h5py serializers registered by distributed.protocol.h5py
    import h5py

    dask_serialize.register((h5py.Group, h5py.Dataset), serialize_h5py_dataset)
    dask_serialize.register(h5py.File, serialize_h5py_file)
    dask_deserialize.register((h5py.Group, h5py.Dataset), deserialize_h5py_dataset)
    dask_deserialize.register(h5py.File, deserialize_h5py_file)

def close_all_files():
    # Close every file opened during deserialization (run this on each worker)
    while deserialized_files:
        file = deserialized_files.pop()
        if file:
            file.close()

#############  Example script (separate file)  ####################

import h5py
import dask
import dask.array as da
import distributed
from dask.distributed import Client
import numpy as np

import logging
logger = logging.getLogger(__name__)

def run_example():

    from deserializers import set_custom_serializers, close_all_files

    print(f"Version of Dask: {dask.__version__}")
    print(f"Version of Distributed: {distributed.__version__}")
    print(f"===============================")

    # Create HDF5 file
    print("Creating HDF5 file")
    fln = "test.h5"
    with h5py.File(fln, "w") as f:
        f.create_dataset("data", data=np.random.random(size=(100, 100)), chunks=(10, 10), dtype="float64")

    print("Creating client")
    client = Client()

    # Register the custom serializers on every worker and in the client process
    client.run(set_custom_serializers)
    set_custom_serializers()

    # Process the file
    print("Loading and processing data")
    with h5py.File(fln, "r") as f:

        data = da.from_array(f["data"], chunks=(10, 10))
        sm_fut = da.sum(data, axis=0).persist(scheduler=client)
        sm = sm_fut.compute(scheduler=client)
        print(f"sm={sm}")

    # Close worker-side file handles, then any opened in the client process
    client.run(close_all_files)
    close_all_files()

    # Try to open file for writing
    print("Attempting to open file for writing")
    try:
        with h5py.File(fln, "r+") as f:
            print("File was opened for writing !!!")
    except OSError as ex:
        logger.exception("Failed to open file for writing: %s", ex)

    print("Closing client")
    client.close()

if __name__ == "__main__":
    run_example()
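
With the custom serializers registered on both the workers and the client process, close_all_files releases the worker-side HDF5 handles, so the final attempt to reopen the file in "r+" mode is expected to succeed; on Dask v2023.9.3 and later without this workaround, the workers keep the file open and the open fails with an OSError.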
  2. Issue with NumPy/Numba. The function implementing the 'snip' method for background subtraction kept failing with a 'list index out of range' error when run with Numba JIT (it worked correctly without Numba JIT). The issue was traced to the use of the numpy.convolve function. The part of the function that used convolution was reimplemented with an explicit loop, multiplication, and summation. In initial tests, the new solution appears to run at least as fast as the old one when compiled with Numba JIT.

Original code based on np.convolve:

    A = s.sum()
    background = np.convolve(background, s) / A
    # Trim 'background' array to imitate the np.convolve option 'mode="same"'
    mg = len(s) - 1
    n_beg = mg // 2
    n_end = n_beg - mg  # Negative
    background = background[n_beg:n_end]
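
For example, with len(s) = 25 this evaluates to background[12:-12], i.e. the centered len(background) elements of the full convolution output, which is what np.convolve(..., mode="same") would return directly.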

Replacement code:

    def convolve(background, s):
        # Equivalent of the trimmed np.convolve result above for a symmetric
        # kernel 's', written as an explicit loop that Numba JIT can compile
        s_len = len(s)
        n_beg = (s_len - 1) // 2
        A = s.sum()
        # Zero-pad so that every window of length s_len stays in bounds
        source = np.hstack(
            (
                np.zeros(n_beg, dtype=background.dtype),
                background,
                np.zeros(s_len - n_beg, dtype=background.dtype),
            )
        )
        # Sliding-window weighted sum, written back into 'background' in place
        for n in range(len(background)):
            background[n] = np.sum(source[n : n + s_len] * s) / A

    convolve(background, s)
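
As a quick sanity check (a hypothetical standalone sketch, not part of the PR, assuming a symmetric smoothing kernel s), the replacement can be compared directly against the trimmed np.convolve output:

    import numpy as np

    def convolve(background, s):
        # Replacement implementation from above
        s_len = len(s)
        n_beg = (s_len - 1) // 2
        A = s.sum()
        source = np.hstack(
            (
                np.zeros(n_beg, dtype=background.dtype),
                background,
                np.zeros(s_len - n_beg, dtype=background.dtype),
            )
        )
        for n in range(len(background)):
            background[n] = np.sum(source[n : n + s_len] * s) / A

    rng = np.random.default_rng(0)
    background = rng.random(1000)
    s = np.exp(-np.linspace(-3, 3, 25) ** 2)  # symmetric (Gaussian-like) kernel

    # Original approach: full convolution trimmed to imitate mode="same"
    expected = np.convolve(background, s) / s.sum()
    mg = len(s) - 1
    expected = expected[mg // 2 : mg // 2 - mg]

    result = background.copy()
    convolve(result, s)
    assert np.allclose(expected, result)

Note that the explicit loop computes a cross-correlation (the kernel is not flipped), so it matches np.convolve only when s is symmetric, which holds for the Gaussian-style smoothing window assumed in this sketch.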