mind-inria / mri-nufft

Doing non-Cartesian MR Imaging has never been so easy.
https://mind-inria.github.io/mri-nufft/
BSD 3-Clause "New" or "Revised" License

Suggestion: Use `plum` for Multiple Dispatch to Optimize Device-Specific NUFFT Operations #205

Closed: gRox167 closed this issue 4 weeks ago

gRox167 commented 1 month ago

Description

Hello, and thank you for your excellent work on the mri-nufft project! I'd like to suggest adopting the multiple dispatch programming paradigm in this project, specifically using the plum package, to improve the device-specific handling of Non-Uniform Fast Fourier Transform (NUFFT) operations.

Motivation

As MRI reconstruction is computationally intensive, it’s vital to leverage different hardware (CPUs, GPUs, etc.) efficiently. Currently, device-specific code often requires branching to check hardware compatibility, which can increase complexity and reduce clarity. Multiple dispatch provides a robust alternative by automatically selecting the correct implementation based on device type, significantly enhancing flexibility, readability, and performance.
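To make the contrast concrete, here is a minimal sketch in plain Python. It uses the standard library's `functools.singledispatch` rather than plum, so it dispatches on the type of the first argument only (plum extends the idea to all arguments). `HostArray`, `DeviceArray` and `TpuArray` are hypothetical stand-ins for array types on different hardware, and the functions return labels instead of computing a transform.

```python
from functools import singledispatch


# Hypothetical stand-ins for host (CPU) and device (GPU) array types.
class HostArray:
    pass


class DeviceArray:
    pass


def nufft_branching(data):
    # Branching style: one function inspects the input and picks a code path.
    if isinstance(data, HostArray):
        return "cpu path"
    if isinstance(data, DeviceArray):
        return "gpu path"
    raise TypeError(f"unsupported input: {type(data).__name__}")


@singledispatch
def nufft_dispatch(data):
    # Fallback used when no implementation is registered for type(data).
    raise TypeError(f"unsupported input: {type(data).__name__}")


@nufft_dispatch.register
def _(data: HostArray):
    return "cpu path"


@nufft_dispatch.register
def _(data: DeviceArray):
    return "gpu path"


# A new device type can be supported later (e.g. from another module)
# without editing any of the functions above.
class TpuArray:
    pass


@nufft_dispatch.register
def _(data: TpuArray):
    return "tpu path"
```

With branching, supporting `TpuArray` would mean editing `nufft_branching`; with dispatch, it only means registering one more implementation.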

As you may know, PyTorch itself uses multiple dispatch to route operations to different implementations depending on the device and data type. PyTorch developers have even said that they are basically rewriting Julia.

Why Use plum for Multiple Dispatch?

The plum package provides a streamlined, Pythonic way to implement multiple dispatch, allowing us to easily define separate, optimized functions for each device. Here’s why plum is an ideal choice for this project:

  1. Simplicity and Python Integration: plum is lightweight and integrates smoothly with Python, allowing clear and concise device-specific methods without requiring complex changes to the existing code structure.

  2. Improved Performance: With plum, each device-specific NUFFT function is dispatched without internal conditionals, allowing faster, more direct execution. This can be particularly advantageous in MRI reconstruction, where NUFFT operators are called many times and efficiency is key.

  3. Maintainability and Readability: Device-specific functions are encapsulated and clearly defined, reducing conditional checks and making the codebase more readable. With plum, it’s clear which function handles which device, improving maintainability and reducing potential errors.

  4. Easy Extensibility: plum enables you to add new device implementations (e.g., for TPUs) or optimized kernels without altering existing code. This future-proofs the project as it grows to support new hardware.

Proposed Implementation

By adopting plum, we can define distinct functions for each device and each operator, which will be dynamically dispatched based on the input type (CPU array, GPU tensor, etc.) and operator type. Here’s an example of what this could look like:

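from typing import Literal

import numpy as np
import torch
from plum import dispatch

# FinufftBaseOp, GpunufftBaseOp, CufinufftBaseOp and GPUArray are placeholder
# names for illustration; they must be defined (or replaced by the real
# operator and array classes) before this runs.

@dispatch
def nufft(data: np.ndarray, op: FinufftBaseOp, *args):
    # NUFFT implementation for CPU (NumPy) data with FINUFFT
    ...

@dispatch
def nufft(data: np.ndarray, op: GpunufftBaseOp, *args):
    # NUFFT implementation for CPU (NumPy) data, computed with gpuNUFFT
    ...

@dispatch
def nufft(data: torch.Tensor, op: CufinufftBaseOp, sens_map: torch.Tensor,
          device: Literal["cpu", "cuda"], *args):
    # NUFFT implementation for torch.Tensor input, with sens_map on `device`
    ...

@dispatch
def nufft(data: GPUArray, op: GpunufftBaseOp, sens_map: torch.Tensor,
          device: Literal["cpu", "cuda"], *args):
    # NUFFT implementation for GPUArray input
    ...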

This way, plum will automatically select the appropriate NUFFT function based on the input's device type and the operator, simplifying the code. Further, 2D NUFFT operators with or without z-stacks, and with or without multicoil data, could also be distinguished through dispatch parameters.

We could also build a type hierarchy to further simplify the code.

Potential Challenges

Migrating to multiple dispatch might require an initial restructuring of the operator interface, and adding this dependency could introduce overhead and make the project bulkier.

Thank you for considering this suggestion! I would be glad to discuss further details and help with the implementation if this aligns with the project's goals.

paquiteau commented 4 weeks ago

Hello there,

Thanks for bringing plum to our attention, but after thinking about it, we are not going to use plum inside MRI-NUFFT. Here are the reasons motivating our decision:

If you wish to wrap MRI-NUFFT with plum for your own stuff, feel free to do it (and please talk to us about it!). Meanwhile I will probably think about using plum for other projects of mine.

I am closing this for now.

[^1]: Maybe we will add type hints here later, and we probably should for linting purposes. I really wish there were a good array typing library with flexible shape and dtype.

gRox167 commented 4 weeks ago

Thanks for your responses! Indeed, there are no significant performance improvements.

BTW, regarding a package for tensor types and shapes: I don't know whether jaxtyping suits your needs, but it supports NumPy, PyTorch and JAX, with both explicit and implicit shape type checking.