xnd-project / numba-xnd

Integrating xnd into numba
https://xnd.io/

Numba integration for XND


Code coverage is undercounted because the coverage tool cannot measure code inside Numba-jitted functions.

Usage

from numba import jit
from xnd import xnd

# register types
import numba_xnd

@numba_xnd.gumath.register_kernel(
    [
        "... * N * M * int64, ... * M * K * int64 -> ... * N * K * int64",
        "... * N * M * float64, ... * M * K * float64 -> ... * N * K * float64",
    ]
)
def simple_matrix_multiply(a, b, c):
    n, m = a.type.shape
    m_, p = b.type.shape
    for i in range(n):
        for j in range(p):
            c[i, j] = 0
            for k in range(m):
                c[i, j] = c[i, j].value + a[i, k].value * b[k, j].value

a = xnd([[1, 2, 3], [4, 5, 6]])
b = xnd([[7, 8], [9, 10], [11, 12]])
c = xnd([[58, 64], [139, 154]])
assert simple_matrix_multiply(a, b) == c
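The `...` prefix in the signatures above means the kernel body only sees the inner `N * M` and `M * K` matrices; any leading dimensions are looped over outside it. A rough plain-Python model of that contract is sketched below. Nested lists stand in for xnd containers, the names `ndim`, `matmul_core` and `matmul_batched` are hypothetical (not part of numba_xnd), and both inputs are assumed to share identical, non-empty leading dimensions instead of following gumath's full broadcasting rules.

```python
from typing import List

Matrix = List[List[int]]


def ndim(x) -> int:
    """Nesting depth of a (non-empty) nested list."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0]
    return depth


def matmul_core(a: Matrix, b: Matrix) -> Matrix:
    """Core kernel for the 'N * M, M * K -> N * K' part of the signature."""
    n, m = len(a), len(a[0])
    m2, k = len(b), len(b[0])
    if m != m2:
        raise ValueError(f"inner dimensions differ: {m} != {m2}")
    c = [[0] * k for _ in range(n)]
    for i in range(n):
        for j in range(k):
            for p in range(m):
                c[i][j] += a[i][p] * b[p][j]
    return c


def matmul_batched(a, b, core_ndim: int = 2):
    """Apply matmul_core over the leading '...' dimensions (no broadcasting)."""
    depth = ndim(a)
    if depth < core_ndim:
        raise ValueError("input has fewer dimensions than the kernel core")
    if depth == core_ndim:
        return matmul_core(a, b)
    if len(a) != len(b):
        raise ValueError("leading dimensions differ")
    # Recurse over the outermost leading dimension until the core is reached.
    return [matmul_batched(x, y, core_ndim) for x, y in zip(a, b)]
```

With the `a` and `b` values from the usage example, `matmul_core` produces `[[58, 64], [139, 154]]`, and stacking each input twice along a new leading dimension yields that result twice.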

Development

conda env create
conda activate numba-xnd
cd numba
python setup.py develop
cd ..
python setup.py develop
# ready to run project scripts

Run tests:

python -m unittest

Updating xnd_structinfo.c:

pip install git+https://github.com/plures/xndtools.git
structinfo_generator structinfo_config.py
python setup.py develop

Project Structure

The directory structure is meant to mirror the plures/* projects. For each plures project, we have a low-level API that wraps the C library (./libndtypes) and a higher-level API that wraps the Python-level library (./ndtypes). The Python-level wrappers should be implemented using the functions in the C level.

For the Python API, we implement lowerings for some of the same functions and methods that exist in the regular Python API, so that a user who wraps a function in jit sees the same API.

Design Notes

Goals:

Approach:

Assumptions:

Open questions:

Current status:

There is some code that implements parts of this in quansight/numba@xnd, but it converts xnd values to NumPy-like arrays when creating gumath kernels. Instead, we should keep everything as xnd types so that we get the full flexibility of the type system.

API:

Supports simple for-loop indexing of xnd values:

@jit(nopython=True)
def sum_1d(a):
    c = 0
    for i in range(a.type.shape[0]):
        c += a[i]
    return c

assert sum_1d(xnd([1, 2, 3])) == 6

@jit(nopython=True)
def sum_attr(a):
    c = 0
    for i in range(a.type.shape[0]):
        c += a[i]['hi']
    return c

assert sum_attr(xnd([{'hi': 1}, {'hi': 2}])) == 3

Any jitted function that takes only xnd arguments and returns xnd values can be registered as a gumath kernel.

Gumath kernels can be created from jitted functions. The last argument is always the output (the return value):

@register_gumath_kernel([
    'N * M * float64, M * P * float64 -> N * P * float64',
    'N * M * int64, M * P * int64 -> N * P * int64'
])
@jit(nopython=True)
def matrix_multiply(a, b, c):
    n, m = a.type.shape
    m, p = b.type.shape
    for i in range(n):
        for j in range(p):
            c[i][j] = 0
            for k in range(m):
                c[i][j] += a[i][k] * b[k][j]

assert len(matrix_multiply.kernels) == 2
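To make the "last argument is the output" convention concrete, here is a minimal plain-Python sketch of what a kernel wrapper has to do before calling the body: read each input's shape, bind the symbolic dimensions (`N`, `M`, `P`), check that repeated names agree, allocate an output with the shape named after `->`, and pass it in last. Every name here (`parse_signature`, `shape`, `bind_dims`, `zeros`, `call_kernel`) is hypothetical rather than numba_xnd or gumath API, nested lists stand in for xnd values, and `...` dimensions are not handled.

```python
from typing import Callable, Dict, List, Tuple


def parse_signature(sig: str) -> Tuple[List[List[str]], List[str]]:
    """Split 'N * M * int64, M * P * int64 -> N * P * int64' into dim names."""
    inputs, output = sig.split("->")

    def dims(t: str) -> List[str]:
        parts = [p.strip() for p in t.split("*")]
        return parts[:-1]  # drop the trailing dtype, keep symbolic dims

    return [dims(t) for t in inputs.split(",")], dims(output)


def shape(x) -> List[int]:
    """Shape of a rectangular nested list."""
    s = []
    while isinstance(x, list):
        s.append(len(x))
        x = x[0] if x else None
    return s


def bind_dims(names: List[str], shp: List[int], env: Dict[str, int]) -> None:
    """Record each symbolic dim's size, checking that repeated names agree."""
    if len(names) != len(shp):
        raise ValueError(f"expected {len(names)} dimensions, got {len(shp)}")
    for name, size in zip(names, shp):
        if env.setdefault(name, size) != size:
            raise ValueError(f"dimension {name}: {env[name]} != {size}")


def zeros(shp: List[int]):
    """Nested list of zeros with the given shape."""
    if not shp:
        return 0
    return [zeros(shp[1:]) for _ in range(shp[0])]


def call_kernel(kernel: Callable, sig: str, *args):
    in_dims, out_dims = parse_signature(sig)
    if len(args) != len(in_dims):
        raise TypeError(f"expected {len(in_dims)} inputs, got {len(args)}")
    env: Dict[str, int] = {}
    for names, arg in zip(in_dims, args):
        bind_dims(names, shape(arg), env)
    out = zeros([env[name] for name in out_dims])
    kernel(*args, out)  # the output is always passed as the last argument
    return out


def matrix_multiply(a, b, c):
    """Plain-Python stand-in for the jitted kernel body."""
    n, m = len(a), len(a[0])
    p = len(b[0])
    for i in range(n):
        for j in range(p):
            c[i][j] = 0
            for k in range(m):
                c[i][j] += a[i][k] * b[k][j]
```

Calling `call_kernel(matrix_multiply, "N * M * int64, M * P * int64 -> N * P * int64", a, b)` with nested-list versions of `a` and `b` from the usage example returns `[[58, 64], [139, 154]]`; a `b` whose first dimension does not match `M` is rejected before the kernel body runs.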

Kernels can also be created with runtime type inspection:

@register_gumath_kernel
@jit(nopython=True)
def add(a, b, c):
    c[tuple()] = xnd(a[tuple()] + b[tuple()])

add(xnd(1), xnd(2))
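Registering without a signature list implies that a specialization is chosen when the kernel is first called with a given combination of input types, similar in spirit to how a Numba dispatcher compiles one version per distinct argument-type signature. A minimal plain-Python sketch of that lazy, cached dispatch follows. `lazy_kernel` and the `specializations` attribute are hypothetical, the "compile" step is only a stand-in that wraps the Python function, and a one-element list plays the role of the 0-d output `c[tuple()]`.

```python
from typing import Callable, Dict, Tuple


def lazy_kernel(py_func: Callable) -> Callable:
    """Specialize py_func on the runtime types of its inputs (illustrative)."""
    cache: Dict[Tuple[type, ...], Callable] = {}

    def compile_for(key: Tuple[type, ...]) -> Callable:
        # Stand-in for real compilation: a real implementation would build
        # a typed kernel for the input types in `key` here.
        def specialized(*args):
            out = [None]  # one-element box acting as the 0-d output c[()]
            py_func(*args, out)
            return out[0]

        specialized.signature = key
        return specialized

    def wrapper(*args):
        key = tuple(type(a) for a in args)  # inspect types at call time
        if key not in cache:
            cache[key] = compile_for(key)
        return cache[key](*args)

    wrapper.specializations = cache
    return wrapper


@lazy_kernel
def add(a, b, c):
    c[0] = a + b
```

Calling `add(1, 2)` and then `add(3, 4)` reuses one cached specialization keyed on `(int, int)`, while `add(1.5, 2.0)` creates a second one for `(float, float)`.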