imh opened this issue 9 months ago
Hi imh,
This is a great topic and I was thinking of giving JAX another crack a few weeks ago. So you are quite on point here! :)
I have always wanted to provide GPU support to Colour, and the first thing I tried 3.5 years ago was JAX. You will even see some discussions I had with the developers here: https://github.com/google/jax/issues/3689. We ended up using CuPy during GSoC as it was much more mature. The main issue with it is that it is tied to NVIDIA GPUs, and I don't have a Windows machine nor is it trivial to test with them, so the PR is parked for now, though it highlighted a few interesting things.
Suffice it to say that I actually did start prototyping with JAX: https://github.com/colour-science/colour/pull/625/commits/69f8d3bf443ad4ee149068301ecc12db424b0f1b#diff-3c8cad174551c24c304eefcdbb2a880c3c7e4ef58f83df1d85af343a6fd3f194R1
One of the core problems was that the JAX import was incredibly slow, but besides that it kinda worked. I'm certainly happy to give it a new crack, but we would go down the backend road as it would allow us to swap JAX for CuPy or something else.
Keen to hear your thoughts!
Okay, after posting this, I did some reading about NumPy's NEPs and other libraries' takes on this kind of thing. tl;dr: there's a long history of proposals, with the Python array API standard finally gaining traction, which may be a good candidate here:
Instead of this:
```python
def user_facing_function(x: ArrayLike):
    return np.exp(x)
```
we do this:
```python
def user_facing_function(x: ImNotSureProbablyAProtocol):
    xp = get_namespace(x)
    return xp.exp(x)
```
Because support for the standard is generally in an experimental state for most libraries, `get_namespace` should depend on an opt-in config, along these lines:
```python
def get_namespace(*xs):
    if not USE_ARRAY_API:
        return np
    # The rest of this would be exactly the implementation suggested in
    # [NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html#appendix-a-possible-get-namespace-implementation)
    # or it would use https://github.com/data-apis/array-api-compat?tab=readme-ov-file
    # as `return array_api_compat.array_namespace(x, y)`
    ...
```
Again, I'm happy to implement this if you agree with the direction.
In more detail:

- `__array_ufunc__` for overriding numpy
- `__array_function__` for overriding numpy
- `xp = get_array_module(...)`
- `xp = get_namespace(...)` and the Python array API standard

The design choices appendix in NEP 37 lays out a lot of the accumulated thinking around them (opt in vs opt out, explicit vs implicit, local vs non-local vs global), which is probably useful for what colour might choose to do. To make colour composable with user code, local control via call-time `xp = np.get_array_module(...)` or `xp = np.get_namespace(...)` seems preferable to a config flag.
NEP 47 (the array API standard) seems to be the one with the most traction across the numeric Python ecosystem, so I'm recommending that.
So far, both the array API standard as well as the implementations in numpy, jax, cupy, dask, etc. are all still experimental, so it seems good to keep it opt-in, like sklearn is currently doing it.
However, static typing could be more complex if colour does this (static typing design doc in the API spec). A protocol should suffice, but it may not be as clean as you have it right now.
This way does increase the complexity and maintenance burden a bit, relative to just changing imports, but still seems worthwhile.
Thanks for digging deeper! I came across some of those in recent years but admittedly haven't checked recently. The sklearn approach seems sensible. A cursory look at skimage did not show anything going in that direction though, I might have missed it!
Most of our functions start with a call to the `colour.utilities.as_float_array` or `colour.utilities.as_int_array` definitions. They could be a good entry point for the namespace jazz. I'm certainly happy for you to give it a go!
Paging @tjdcs for VIS!
Cheers,
Thomas
Thanks! You know I love performance.
Someone told me about jaxtyping the other day supporting annotations for array shapes. I'm not sure if that's true, but it would be very helpful.
https://jax.readthedocs.io/en/latest/jax.typing.html
One thing that concerns me long term for maintainability is putting a lot of configurable dependency switches in. Maybe we should consider just completely switching... or depending on both... Not sure that the latter is really the best option, but it might be required anyway.
FYI I'm not dropping this, but I'll return to it once https://github.com/data-apis/array-api-compat/issues/83 is in, since it seems like the simplest, closest to "standard" way to do it.
> One thing that concerns me long term for maintainability is putting a lot of configurable dependency switches in.
It's worth mentioning that writing "array-agnostic" code like @imh 's example with `xp = get_namespace(...)` _(`get_namespace` is now a legacy alias for `array_namespace`, btw)_ does not introduce a dependency on other libraries like JAX. The idea is that all of the functions that are needed are contained in the `xp` namespace, which comes from input arrays, rather than being imported at this level. A namespace that complies with the standard will 'just work', without colour even needing to know which library it is.
For now, the only additional dependency would be `array-api-compat`, which bridges the gap between the current implementations and full compliance with the standard, but eventually that will not be needed and the namespaces will be accessible directly via `x.__array_namespace__`.
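To make the "no new dependency" point concrete, here is a minimal sketch. The `normalise` function is hypothetical, and the trivial `array_namespace` stand-in just returns NumPy; the real helper from `array-api-compat` inspects the inputs and returns whichever compliant namespace they belong to:

```python
import numpy as np

def array_namespace(*arrays):
    # Hypothetical stand-in for array_api_compat.array_namespace: this sketch
    # always returns NumPy, whereas the real helper returns the namespace of
    # the input arrays (JAX, CuPy, ...), with no import of those libraries here.
    return np

def normalise(x):
    # Array-agnostic body: every operation is looked up on `xp`, so nothing
    # in this function mentions a specific array library.
    xp = array_namespace(x)
    return x / xp.max(xp.abs(x))

print(normalise(np.array([1.0, -4.0, 2.0])))  # values: 0.25, -1.0, 0.5
```

If the inputs were JAX arrays and the real `array_namespace` were used, the exact same body would run on JAX without any change.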
If you'd like more info on the standard, see this talk from SciPy 2023, and my blog post describing how we're using it in SciPy might be helpful to see the perspective of an 'array-consumer' library. (and feel free to ask me any questions!)
Thanks @lucascolley, fantastic insight! Will read/watch your links.
I watched the presentation and that was super helpful! I foresee a few issues and I'm not sure how to solve them with what `array-api-compat` offers:
We have plenty of functions that accept scalar (and more generally ArrayLike) input and this obviously does not work:
```python
import array_api_compat

def gamma_function(x, y):
    xp = array_api_compat.array_namespace(x, y)
    return xp.power(x, 1 / y)

gamma_function(0.18, 2.2)
gamma_function(0.18, [2.0, 2.1, 2.2, 2.3])
```
It means that we would now need to do this:
```python
import numpy as np
import array_api_compat

def gamma_function(x, y):
    xp = array_api_compat.array_namespace(x, y)
    return xp.power(x, 1 / y)

gamma_function(np.array(0.18), np.array(2.2))
gamma_function(np.array(0.18), np.array([2.0, 2.1, 2.2, 2.3]))
```
Which is a significant loss in user experience; not only that, but it would break a TON of downstream dependent code. This slide surprised me because it seems like the focus was put on the low-level libraries and not the libraries that use them, e.g., Colour:
We have a lot of datasets and matrices of all sorts that are imported:
```python
CAT_VON_KRIES: NDArrayFloat = np.array(
    [
        [0.4002400, 0.7076000, -0.0808100],
        [-0.2263000, 1.1653200, 0.0457000],
        [0.0000000, 0.0000000, 0.9182200],
    ]
)
"""
*Von Kries* chromatic adaptation transform.

References
----------
:cite:`CIETC1-321994b`, :cite:`Fairchild2013ba`, :cite:`Lindbloom2009g`,
:cite:`Nayatani1995a`
"""
```
Those use NumPy at the moment, but how do we handle them with `array-api-compat`? This was actually one of the big problems when we tried to adopt CuPy.
Is there a better place to discuss those? They might merit more visibility as I'm sure we are not/will not be the only package with such questions!
> We have plenty of functions that accept scalar (and more generally ArrayLike) input and this obviously does not work
Here's how we do it in SciPy: we want to keep accepting array-likes (which often end up being turned into `np` arrays during computation) with the NumPy backend, but not for alternative backends. If you want to use a JAX array, you must pass a proper array. But other unrecognised inputs are used with the NumPy backend.
We do that by wrapping `array_namespace` as follows:
```python
def array_namespace(*arrays):
    if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        # here we could wrap the namespace if needed
        return np_compat

    arrays = [array for array in arrays if array is not None]
    arrays = compliance_scipy(arrays)

    return array_api_compat.array_namespace(*arrays)


def compliance_scipy(arrays):
    for i in range(len(arrays)):
        array = arrays[i]
        if isinstance(array, np.ma.MaskedArray):
            raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
        elif isinstance(array, np.matrix):
            raise TypeError("Inputs of type `numpy.matrix` are not supported.")
        if isinstance(array, (np.ndarray, np.generic)):
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                raise TypeError(f"An argument has dtype `{dtype!r}`; "
                                f"only boolean and numerical dtypes are supported.")
        elif not is_array_api_obj(array):
            try:
                array = np.asanyarray(array)
            except TypeError:
                raise TypeError("An argument is neither array API compatible nor "
                                "coercible by NumPy.")
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                message = (
                    f"An argument was coerced to an unsupported dtype `{dtype!r}`; "
                    f"only boolean and numerical dtypes are supported."
                )
                raise TypeError(message)
            arrays[i] = array
    return arrays
```
This means that `array_api_compat.numpy` is always returned unless the experimental environment variable is set, maintaining current behaviour for all array types for the time being. When we do set the env variable, we try to coerce array-likes with `np.asanyarray` before passing to `array_api_compat`.
This treatment of NumPy as the default backend allows you to have two code branches in functions, based on whether `xp` is (the compat version of) NumPy. In SciPy, since we have a lot of compiled code which only works with NumPy for now, this is needed quite a lot, but for pure Python + NumPy code, most of it should be easily convertible to array-agnostic code.
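The two-branch pattern could be sketched as below. Everything here is illustrative: the `is_numpy_namespace` helper and `running_total` function are made up, and SciPy's real check compares against the `array_api_compat`-wrapped NumPy namespace rather than `np` itself:

```python
import numpy as np

def is_numpy_namespace(xp):
    # Hypothetical check; real code would compare against the compat-wrapped
    # NumPy namespace, not the bare module.
    return xp is np

def running_total(x, xp):
    if is_numpy_namespace(xp):
        # NumPy branch: free to dispatch into compiled code that only
        # understands np.ndarray (SciPy's situation with C/Fortran kernels).
        return np.cumsum(x)
    # Array-agnostic branch for all other backends; `cumulative_sum` is the
    # array API standard's name for this operation.
    return xp.cumulative_sum(x)

print(running_total(np.array([1, 2, 3]), np))  # values: 1, 3, 6
```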
> We have a lot of datasets, matrices of all sorts that are imported (using `np.array`)
While not guaranteed by the standard, every library we have worked with so far can coerce `np` arrays with `xp.asarray`. That is good enough for now - at some point in the future, when the goal is really to be as portable as possible, the standard includes some specification of device-interchange with DLPack.
~If performance overhead is the concern, I suppose (just brainstorming) you could have a function like `get_dataset(dataset_name, xp)` which returns `xp.asarray({dataset data})`. As mentioned above, this would complicate things for static typing, so converting from NumPy may be the best option for now.~ Thinking more, this would still involve a device copy. If you want to boost performance for a specific library, the best thing would probably be to use that library's creation functions conditionally on the namespace and device of the input.
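A minimal sketch of the coercion idea: the constant stays a plain NumPy array at module level (reusing the Von Kries matrix quoted earlier) and is converted into the caller's namespace on use. The `apply_cat` function name and its row-vector convention are hypothetical:

```python
import numpy as np

# Module-level constant stays a plain NumPy array, as in colour today.
CAT_VON_KRIES = np.array(
    [
        [0.4002400, 0.7076000, -0.0808100],
        [-0.2263000, 1.1653200, 0.0457000],
        [0.0000000, 0.0000000, 0.9182200],
    ]
)

def apply_cat(XYZ, xp=np):
    # Most libraries can coerce NumPy arrays via `asarray`, even though the
    # standard does not guarantee it; for a non-NumPy `xp` this implies a
    # copy (and possibly a host-to-device transfer).
    M = xp.asarray(CAT_VON_KRIES)
    return XYZ @ M.T  # row-vector convention, purely for illustration

print(apply_cat(np.ones(3)))  # row sums of the matrix
```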
> Is there a better place to discuss those? They might merit more visibility as I'm sure we are not/will not be the only package with such questions!
There definitely needs to be some guidance on topics like this (perhaps a follow-up to my blog post), but it is still very early days in adoption. Dare I say you are a bit ahead of the curve here 😉 - packages which depend on the "low-level libraries" will have to wait for those libraries to become compatible before being able to become fully compatible themselves.
In SciPy, we are still figuring out the best way to go about things, and will likely upstream tools/helpers which are generally useful across the NumPy ecosystem to a new repo at some point. Once the foundations are in place and settled in the core libraries, it will be easier to give advice to downstream packages.
That said, feel free to open an issue on the array-api-compat or array-api repos if you'd prefer to discuss over there!
I've gone ahead and pinned this issue. In general we are heavy users of array "stuff". It's probably 90% of our code in some way, and I think this is a really important and exciting path. Being able to support multiple back-ends would be very powerful.
With that said... I also don't want perfect to be the enemy of the good. Maybe it's worth continuing to pursue JAX integration and updating more of our numpy-specific code? Thoughts @KelSolaar @imh
Thanks @lucascolley!
@tjdcs: The SciPy approach looks sensible to me as it ensures that we are not breaking backward compatibility, which is something I would like to avoid at all cost. There is too much code dependent on Colour, so we need to be a bit careful. It seems like the stars are better aligned to start working on that compared to 4 years ago, so I'm keen to explore that more for sure.
+1 to the thanks, @lucascolley!
As far as arraylike datasets go, it seems like they should be promoted into whatever the user is using, since they're originating outside array-api-land. For example, Pointer's gamut remains a numpy array where it's defined, and if the user passes in jax arrays, then the gamut constant gets promoted from a numpy array to a jax array.
For functions accepting arraylike, it's too bad we can't reasonably tell which arguments are deliberately set by the user and which are just default parameters, otherwise we'd probably still want to promote to whatever the user deliberately passed in. We could approximate that with an `array_namespace` that promotes as follows (pure python --> numpy --> other array api). For example:

- `array_namespace(jax_array, python_list)` = jax api
- `array_namespace(dask_array, numpy_array)` = dask api
- `array_namespace(python_list, numpy_array)` = numpy api
- `array_namespace(jax_array, cupy_array)` = error

I have a nebulous bad feeling about that hiding user errors though, so we could alternately just promote (pure python --> array api) and handle constants separately:

- `array_namespace(jax_array, python_list)` = jax api
- `array_namespace(dask_array, numpy_array)` = error
- `array_namespace(python_list, numpy_array)` = numpy api
- `array_namespace(jax_array, cupy_array)` = error

In SciPy it is simpler. If you want one of your array inputs to be a JAX array, they all must be (see e.g. https://github.com/scipy/scipy/issues/18286#issuecomment-1507795400).
These array inputs tend to not have default parameters. Where they do, it should be possible to change the default to a new object in a backwards-compatible way, so that you can distinguish between default input and genuine Python-list / `None` input.
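The sentinel-default trick described above could look like this; the `whitepoint_xy` function and its values are purely illustrative:

```python
# A unique sentinel object distinguishes "argument omitted" from any value a
# caller could legitimately pass, including None or a Python list.
_DEFAULT = object()

def whitepoint_xy(wp=_DEFAULT):
    if wp is _DEFAULT:
        # Argument omitted: fall back to the previous default, keeping
        # backwards compatibility for existing callers.
        wp = [0.3127, 0.3290]
    return list(wp)

print(whitepoint_xy())              # [0.3127, 0.329]
print(whitepoint_xy([0.31, 0.33]))  # [0.31, 0.33]
```

Because the sentinel is a private singleton, `wp is _DEFAULT` can never be true for user-supplied input, so existing callers are unaffected.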
It would be worth giving https://github.com/scipy/scipy/issues/18286 a look if you are considering going down the array API route. Clearly there is a lot more complexity for a package like SciPy, but it might be helpful to spell out the overall aims/strategy.
A few comments on some things discussed here:
> Which is a significant loss in user experience; not only that, but it would break a TON of downstream dependent code. This slide surprised me because it seems like the focus was put on the low-level libraries and not the libraries that use them, e.g., Colour:
I wouldn't worry too much about this. The data I discussed on this slide was primarily used to get a reasonable list of APIs for inclusion in the initial version of the standard. Since then, more APIs and behaviors have been added to the standard based on user feedback. If something you need is missing, open an issue in the array-api repo.
> While not guaranteed by the standard, every library we have worked with so far can coerce np arrays with xp.asarray. That is good enough for now - at some point in the future, when the goal is really to be as portable as possible, the standard includes some specification of device-interchange with DLPack.
asarray does support inputs that support the buffer protocol. There's also DLPack, which is probably preferable as the buffer protocol is CPU only (see https://data-apis.org/array-api/latest/design_topics/data_interchange.html).
I'm curious how scipy handles this. Does scipy not also have hard-coded data, or is that only in C and Fortran codes like fft?
> We have plenty of functions that accept scalar (and more generally ArrayLike) input and this obviously does not work:
The way I see it is this: if a user is using a library that doesn't accept array-like inputs (for example, `torch` functions generally do not allow lists as inputs), then they already have the expectation that everything should be an array first, and will carry that expectation to colour. If they are using a library like NumPy that does allow it, then all the functions in `array-api-compat` will allow that (array-api-compat tries to maintain library behaviors that aren't required by the standard), so passing a list will work. That's with the minor caveat that you would need a wrapper like SciPy's to make list-only arguments to `array_namespace` default to NumPy (maybe we should add a `default_library` flag to `array_namespace`).
Of course, this does mean your internal function calls and tests will need to be a little more rigorous about calling `asarray` first. But that's really just a special case of a general fact: if you want to support the array API in colour, you will need to only use array API-compatible APIs and behaviors everywhere. Functions only accepting arrays as inputs is one instance of that, but there are many others that you would need to update your code for as well (this slide from my presentation gives an idea of the sorts of changes typically required).
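The "asarray first" discipline applied to the earlier `gamma_function` example might look like the sketch below. NumPy is hard-wired as the assumed default namespace here; a real version would obtain `xp` from a wrapper like SciPy's `array_namespace`:

```python
import numpy as np

def gamma_function(x, y):
    # Assumption for this sketch: NumPy is the default namespace. A real
    # implementation would pick `xp` from the inputs via a wrapper that
    # falls back to NumPy for non-array arguments.
    xp = np
    # Coerce scalars and lists up front, so the rest of the body only ever
    # sees proper arrays.
    x, y = xp.asarray(x), xp.asarray(y)
    return xp.power(x, 1 / y)

gamma_function(0.18, 2.2)                   # works with scalars again
gamma_function(0.18, [2.0, 2.1, 2.2, 2.3])  # and with lists
```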
> I'm curious how scipy handles this

I'm not sure whether this has come up in public API functions yet. In tests, we just create NumPy arrays and convert (copy if a different device) with `xp.asarray`.
@imh: Have you put some more thought into all this by any chance?
I had a brief look into the Colour codebase today. My one thought was that you may want to wait for a resolution to https://github.com/data-apis/array-api/pull/589, given that everything here is statically typed.
I also didn't spy any parametrized tests, which are the easiest way to test multiple `xp` backends without duplicating test code. Not sure whether this would require anything extra to work with `unittest`.
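For reference, a pytest-style parametrization over namespaces might look like this sketch. Only NumPy is listed; the hypothetical `XP_BACKENDS` list would gain JAX, CuPy, etc. conditionally on availability:

```python
import numpy as np
import pytest

# Hypothetical registry of namespaces under test; alternative backends would
# be appended here only when they can be imported.
XP_BACKENDS = [np]

@pytest.mark.parametrize("xp", XP_BACKENDS)
def test_exp_shape(xp):
    # The same test body runs once per backend, with no duplicated code.
    x = xp.asarray([0.0, 1.0, 2.0])
    assert xp.exp(x).shape == (3,)
```

Each backend then shows up as a separate test case in the report, which makes backend-specific failures easy to spot.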
Linking some array compatibility files for reference:
Hi folks, I'm back from a long github break (became a dad woohoo, took a sabbatical). I agree that it doesn't make much sense to go forward until https://github.com/data-apis/array-api/pull/589 is in.
Description

JAX includes a NumPy-compatible `jax.numpy` module which has a bunch of nice features (automatic differentiation, JIT compilation, vectorized mapping, GPU runtime, JS export). They've taken great pains to make sure it's usually as simple as swapping `import numpy as np` for `import jax.numpy as np`. Likewise (but less extensively) for the `jax.scipy` module.

I'd like to do some optimization for which it would be really convenient to automatically differentiate some of the great stuff you've implemented and export it to JS. It should be as simple as changing `import numpy as np` around the library. Changing the type signatures probably has more degrees of freedom we can choose from, but is basically the same.

I'd be happy to implement it, but don't want to make a PR that you don't want. I expect that the added maintenance burden would be pretty minimal.
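The one-line swap described above can be sketched as follows; the `eotf_sketch` transfer function is hypothetical, and the `jax.numpy` line is commented out so the snippet runs with NumPy alone:

```python
import numpy as np
# import jax.numpy as np  # the proposed one-line swap, when JAX is installed

def eotf_sketch(x):
    # Hypothetical gamma-2.2 decoding; any function written against the
    # shared numpy-compatible API runs unchanged under either import.
    return np.power(np.asarray(x), 2.2)

print(eotf_sketch([0.0, 0.5, 1.0]))
```

With the JAX import active, the same function becomes differentiable via `jax.grad` and JIT-compilable via `jax.jit`, with no change to its body.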