ImperialCollegeLondon / pyrealm

Development of the pyrealm package, providing an integrated toolbox for modelling plant productivity, growth and demography using Python.
https://pyrealm.readthedocs.io/
MIT License

Array API compatibility implementation #269

Open j-emberton opened 1 month ago

j-emberton commented 1 month ago

The array API standard makes it feasible to take advantage of increased interoperability between array libraries. This means a single array-based codebase can be deployed efficiently while remaining agnostic to the underlying array type, i.e. the arrays can come from NumPy, CuPy, Dask, JAX, or any other array-API-compliant library. This makes it simple to take code organised around NumPy (single-threaded CPU) and deploy it in multithreaded CPU (e.g. Dask) or alternative architecture (e.g. CUDA GPU) scenarios. This approach also allows the use of lazy evaluation (Dask, JAX) rather than eager evaluation (NumPy), and (in principle) just-in-time compilation (JAX).
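As a minimal illustration of that portability (a sketch, assuming dask is installed; scaled_sum is just an illustrative name), the same array-consuming function runs eagerly on NumPy and lazily on Dask:

import numpy as np
import dask.array as da


def scaled_sum(x):
    # Backend-agnostic: uses only operations both libraries provide
    return (2.0 * x).sum()


eager = scaled_sum(np.arange(6))           # NumPy: evaluated immediately
lazy = scaled_sum(da.arange(6, chunks=3))  # Dask: builds a task graph
print(eager, lazy.compute())               # .compute() triggers evaluation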

Details on its aims, implementation and stakeholders can be found here: Array API

The array_api_compat library has been developed to act as a common interface that simplifies the practical implementation of this intent: a pure Python wrapper providing aliases and helper functions. Array API compat

There aren't many alternatives to this approach that capture both speedup through CPU parallelisation and GPU support while maintaining an interface close to pure NumPy.

Dependencies: we wouldn't need to add any extra array libraries as hard dependencies; they can be used on an 'if available' basis, with a soft-dependency pattern like the sketch below. NumPy should remain the default array processing library.
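A common pattern for such soft dependencies (a sketch; CuPy stands in for any optional backend):

try:
    import cupy as cp  # optional GPU-backed array library
except ImportError:
    cp = None  # not installed: fall back to the NumPy default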

Challenges: identifying which parts of the code are naturally agnostic to where they are processed, and which are constrained to the CPU (e.g. file I/O). We might need to add checks that move objects in memory to the required device (e.g. if processing on a CUDA GPU is requested). The array_api_compat library provides device() and to_device() helpers to make this simple (hopefully).
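A rough sketch of what such a check could look like (move_to_device and target are illustrative names, not pyrealm API; device() and to_device() are the array_api_compat helpers mentioned above):

from array_api_compat import device, to_device


def move_to_device(x, target):
    # Query where the array currently lives and relocate it if needed.
    # These helpers cover libraries (like NumPy) that lack native
    # .device / .to_device support.
    if device(x) != target:
        x = to_device(x, target)
    return x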

I've added a code snippet below to demonstrate how this can be implemented and how it leads to array-type interoperability. At the core is the namespace query, which determines which array implementation has been passed to the function and then uses the correct implementation of the required operation (matmul in this case), i.e. the NumPy version, the Dask version, etc. Simple scalar operations on arrays or basic element-wise multiplication can be done as normal, I think.

Any changes we make should be fully backwards compatible, as they extend the functionality of a function or class rather than modifying it. All tests can remain the same.

from typing import Union

from array_api_compat import array_namespace
import array_api_compat.numpy as np
import array_api_compat.torch as torch
from array_api_compat.torch import Tensor
from numpy.typing import NDArray

APIArray = Union[NDArray, Tensor]


def function(x: APIArray, y: APIArray) -> APIArray:

    # Determine the namespace (NumPy, torch, ...) from the input arrays
    xp = array_namespace(x, y)

    # Now use xp as the array library namespace
    return xp.matmul(x, y)


def set_backend(backend: str):

    # Map a backend name to its array-API-compatible namespace
    if backend == "numpy":
        return np
    elif backend == "torch":
        return torch
    else:
        raise ValueError(f"Unknown backend: {backend}")


backend = set_backend("numpy")

a = backend.asarray([(1, 2, 3), (4, 5, 6)])    # 2x3
b = backend.asarray([(1, 2), (3, 4), (5, 6)])  # 3x2

c = function(a, b)  # 2x2 matrix product via the NumPy implementation of matmul

print(c)
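For comparison (assuming torch is installed), switching the backend exercises exactly the same function: array_namespace now resolves to the torch namespace, so matmul is the torch implementation and the result is a torch Tensor.

backend = set_backend("torch")

a = backend.asarray([(1, 2, 3), (4, 5, 6)], dtype=backend.float32)
b = backend.asarray([(1, 2), (3, 4), (5, 6)], dtype=backend.float32)

print(function(a, b))  # same code path, torch result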