scverse / anndata

Annotated data.
http://anndata.readthedocs.io
BSD 3-Clause "New" or "Revised" License

General array-api support #1195

Open ivirshup opened 1 year ago

ivirshup commented 1 year ago

Please describe your wishes and possible alternatives to achieve the desired result.

With the array API, it should be possible to have some level of support for a greater variety of array types, including PyTorch and JAX arrays.

We have a proof of concept using array-api-compat on top of CuPy arrays.

Challenges

Dispatch

We currently use type-based dispatch to select appropriate methods for different array types. It's not immediately clear to me how we would integrate the array API here. Potentially with a Protocol or ABC?
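As a rough illustration of the Protocol option, here is a minimal sketch of a runtime-checkable structural type keyed on the `__array_namespace__` method that the array API standard defines on conforming arrays. The names `SupportsArrayNamespace`, `FakeArray`, and `describe` are hypothetical and not part of anndata. A check like this would miss array types that only conform through a compatibility layer, so treat it as one possible starting point rather than a settled design.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class SupportsArrayNamespace(Protocol):
    """Hypothetical structural type: anything exposing `__array_namespace__`."""

    def __array_namespace__(self, *, api_version: Optional[str] = None) -> Any: ...


class FakeArray:
    """Hypothetical stand-in for an array type that implements the standard."""

    def __array_namespace__(self, *, api_version: Optional[str] = None) -> Any:
        return "fake-namespace"


def describe(x: object) -> str:
    # isinstance() on a runtime-checkable Protocol only checks that the
    # method exists; it does not validate the signature or the namespace.
    if isinstance(x, SupportsArrayNamespace):
        return "array-api"
    return "other"


kind_of_fake = describe(FakeArray())
kind_of_list = describe([1, 2, 3])
```

Because the check is structural, no array library has to be imported or registered up front, which is the main appeal over listing concrete types.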

Specialized code for performance

We currently have a bunch of non-standard code for various operations for performance reasons. For instance, using pandas indexing code on numpy arrays.

IO dispatch

IO dispatch is particularly difficult. It's likely we just need to use the dlpack based interchange here.
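To make the DLPack idea concrete, here is a minimal sketch that assumes the writer ultimately wants a NumPy array. `to_numpy_for_write` is a hypothetical helper, not an anndata function, and a NumPy array stands in for the source so the example runs without other array libraries. `np.from_dlpack` only accepts CPU data, so device arrays would presumably need to be moved to the host first.

```python
import numpy as np


def to_numpy_for_write(x):
    """Hypothetical helper: import any DLPack-exporting array as a NumPy array."""
    # np.from_dlpack consumes objects implementing __dlpack__ / __dlpack_device__.
    return np.from_dlpack(x)


# Stand-in for an array from another library that supports DLPack.
src = np.arange(6, dtype=np.float32).reshape(2, 3)
out = to_numpy_for_write(src)
```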

Views

Current view infrastructure is very type based. Right now, I think we'd either need to rework this or have special code for each array type, which kinda defeats the purpose. Maybe an xarray dependency would be easier here.

Sparse

The array API currently doesn't account for sparse types.

A notable problem here is IO, since dlpack based interchange would densify the array.

Downstream libraries

Finally, how do downstream libraries deal with this? It creates the possibility of passing these libraries array types they haven't seen before. What is the appropriate behavior on their end? Do you ask the user to convert? Does the library convert on its own (even if that could be extremely expensive)?
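One way to frame the two options is as an explicit policy that the downstream library exposes. The sketch below is purely illustrative: `accept_array`, `known_types`, `convert`, and `on_unknown` are hypothetical names, and plain Python containers stand in for array types.

```python
from typing import Any, Callable, Optional, Tuple


def accept_array(
    x: Any,
    known_types: Tuple[type, ...],
    convert: Optional[Callable[[Any], Any]] = None,
    on_unknown: str = "raise",
) -> Any:
    """Return x if its type is known; otherwise convert or raise, per on_unknown."""
    if isinstance(x, known_types):
        return x
    if on_unknown == "convert" and convert is not None:
        # Implicit conversion: convenient, but potentially very expensive.
        return convert(x)
    # Default: ask the user to convert explicitly.
    raise TypeError(
        f"Unsupported array type {type(x).__name__}; please convert it first."
    )


passed_through = accept_array([1, 2], known_types=(list,))
converted = accept_array((1, 2), known_types=(list,), convert=list, on_unknown="convert")
```

Defaulting to an error keeps expensive copies visible to the user, while the opt-in conversion covers cases where convenience matters more than cost.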

flying-sheep commented 2 months ago

We can define a type if none exists, using the same pattern that abc.ABCMeta/abc.ABC/typing.Protocol use: __instancecheck__ (here combined with is_array_api_obj):

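import array_api_compat

class ArrayApiMeta(type):
    # isinstance(x, ArrayApi) is routed here via the metaclass
    def __instancecheck__(cls, instance: object) -> bool:
        return array_api_compat.is_array_api_obj(instance)

# the keyword must be `metaclass`; `meta=` raises a TypeError at class creation
class ArrayApi(metaclass=ArrayApiMeta):
    pass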