Code coverage is undercounted because coverage tools cannot measure Numba-jitted functions.
```python
from numba import jit
from xnd import xnd

# register types
import numba_xnd


@numba_xnd.gumath.register_kernel(
    [
        "... * N * M * int64, ... * M * K * int64 -> ... * N * K * int64",
        "... * N * M * float64, ... * M * K * float64 -> ... * N * K * float64",
    ]
)
def simple_matrix_multiply(a, b, c):
    n, m = a.type.shape
    m_, p = b.type.shape
    for i in range(n):
        for j in range(p):
            c[i, j] = 0
            for k in range(m):
                c[i, j] = c[i, j].value + a[i, k].value * b[k, j].value


a = xnd([[1, 2, 3], [4, 5, 6]])
b = xnd([[7, 8], [9, 10], [11, 12]])
c = xnd([[58, 64], [139, 154]])
assert simple_matrix_multiply(a, b) == c
```
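The kernel above declares three parameters but is called with two: the last parameter, `c`, is the output, which the caller allocates and the kernel fills in place. The plain-Python sketch below mimics that calling convention with nested lists so it runs without xnd or Numba. `with_output` and `allocate_matmul_output` are hypothetical names standing in for the output allocation that gumath presumably performs from the output type; they are not part of numba-xnd.

```python
def with_output(kernel, allocate_output):
    """Hypothetical wrapper: allocate the output, run kernel(*inputs, out), return out."""
    def call(*inputs):
        out = allocate_output(*inputs)
        kernel(*inputs, out)  # the kernel writes its result into `out` in place
        return out
    return call


def matmul_kernel(a, b, c):
    """Same loop as simple_matrix_multiply, on nested lists instead of xnd."""
    n, m, p = len(a), len(b), len(b[0])
    for i in range(n):
        for j in range(p):
            c[i][j] = 0
            for k in range(m):
                c[i][j] += a[i][k] * b[k][j]


def allocate_matmul_output(a, b):
    # Output shape N * K: rows of `a` times columns of `b`.
    return [[0] * len(b[0]) for _ in range(len(a))]


matmul = with_output(matmul_kernel, allocate_matmul_output)
print(matmul([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]]))
# [[58, 64], [139, 154]]
```

Keeping allocation outside the kernel lets the same loop body serve every registered signature, since only the allocation step needs to know the output's dtype and shape.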
```bash
conda env create
conda activate numba-xnd
cd numba
python setup.py develop
cd ..
python setup.py develop
# ready to run project scripts
```
Run tests:

```bash
python -m unittest
```
Updating `xnd_structinfo.c`:

```bash
pip install git+https://github.com/plures/xndtools.git
structinfo_generator structinfo_config.py
python setup.py develop
```
The directory structure is meant to mirror the plures/* projects. For each plures project, there is a low-level API wrapping the C library (e.g. `./libndtypes`) and a higher-level API wrapping the Python level (e.g. `./ndtypes`). The Python-level wrappers should be built on the functions in the C level.
For the Python API, we implement lowerings for some of the same functions and methods that exist at the normal Python level, so that a user who wraps a function in `jit` gets the same API.
Goals:
Approach:
Assumptions:
Open questions:
`xnd_master_t` struct?

Current status:
There is some code that implements parts of this in quansight/numba@xnd, but it converts xnd values to NumPy-like arrays for gumath kernel creation. Instead, we should keep everything as xnd types so we get the full flexibility of the type system.
API:
Allows simple indexing of xnd types:
```python
@jit(nopython=True)
def sum_1d(a):
    c = 0
    for i in range(a.type.shape[0]):
        c += a[i]
    return c


assert sum_1d(xnd([1, 2, 3])) == 6


@jit(nopython=True)
def sum_attr(a):
    c = 0
    for i in range(a.type.shape[0]):
        c += a[i]['hi']
    return c


assert sum_attr(xnd([{'hi': 1}, {'hi': 2}])) == 3
```
Any jitted function that takes only xnd arguments and returns xnd values can be registered as a gumath kernel.
Can create gumath kernels. The last argument is always the return value:
```python
@register_gumath_kernel([
    'N * M * float64, M * P * float64 -> N * P * float64',
    'N * M * int64, M * P * int64 -> N * P * int64'
])
@jit(nopython=True)
def matrix_multiply(a, b, c):
    n, m = a.type.shape
    m, p = b.type.shape
    for i in range(n):
        for j in range(p):
            c[i][j] = 0
            for k in range(m):
                c[i][j] += a[i][k] * b[k][j]


assert len(matrix_multiply.kernels) == 2
```
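Registering two signatures yields two kernels, one per dtype. At call time the argument types have to be matched against those signatures, binding each named dimension (`N`, `M`, `P`) to a concrete size and deriving the output shape. The pure-Python sketch below illustrates that resolution step on plain shape tuples and dtype names. `parse_type`, `match` and `resolve` are hypothetical helpers, not gumath or numba-xnd API. The sketch also accepts `... *` signatures like those in the first example, but, as a simplifying assumption, it requires the leading `...` dimensions of all inputs to be identical rather than modelling any broadcasting the real library may do.

```python
def parse_type(type_str):
    """Split 'N * M * int64' or '... * N * M * int64' into (has_ellipsis, dims, dtype)."""
    parts = [p.strip() for p in type_str.split("*")]
    has_ellipsis = parts[0] == "..."
    dims = parts[1:-1] if has_ellipsis else parts[:-1]
    return has_ellipsis, dims, parts[-1]


def match(signature, arg_specs):
    """Match (shape, dtype) pairs against one signature.

    Returns (output_shape, output_dtype), or None if the signature does not apply.
    """
    inputs_part, output_part = signature.split("->")
    input_types = [parse_type(t) for t in inputs_part.split(",")]
    if len(input_types) != len(arg_specs):
        return None
    bindings = {}  # dimension symbol -> bound size
    outer = ()     # sizes matched by the leading '...'
    for i, ((has_ellipsis, dims, dtype), (shape, arg_dtype)) in enumerate(
        zip(input_types, arg_specs)
    ):
        if dtype != arg_dtype:
            return None
        n_outer = len(shape) - len(dims)
        if n_outer < 0 or (n_outer > 0 and not has_ellipsis):
            return None
        if i == 0:
            outer = tuple(shape[:n_outer])
        elif tuple(shape[:n_outer]) != outer:
            return None  # this sketch does not broadcast outer dimensions
        for name, size in zip(dims, shape[n_outer:]):
            if bindings.setdefault(name, size) != size:
                return None  # same symbol bound to two different sizes
    has_ellipsis, out_dims, out_dtype = parse_type(output_part)
    prefix = outer if has_ellipsis else ()
    return prefix + tuple(bindings[name] for name in out_dims), out_dtype


def resolve(signatures, arg_specs):
    """Return (signature, output_shape, output_dtype) for the first matching signature."""
    for sig in signatures:
        result = match(sig, arg_specs)
        if result is not None:
            return (sig,) + result
    raise TypeError(f"no registered signature matches {arg_specs}")


SIGNATURES = [
    "N * M * float64, M * P * float64 -> N * P * float64",
    "N * M * int64, M * P * int64 -> N * P * int64",
]
BATCHED = ["... * N * M * int64, ... * M * K * int64 -> ... * N * K * int64"]

print(resolve(SIGNATURES, [((2, 3), "int64"), ((3, 2), "int64")])[1:])
# ((2, 2), 'int64')
print(resolve(BATCHED, [((5, 2, 3), "int64"), ((5, 3, 2), "int64")])[1:])
# ((5, 2, 2), 'int64')
```

Binding a symbol on first use and rejecting later conflicting sizes is what makes `M` enforce that the inner dimensions of the two operands agree.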
Can also create them with runtime type inspection:

```python
@register_gumath_kernel
@jit(nopython=True)
def add(a, b, c):
    c[tuple()] = xnd(a[tuple()] + b[tuple()])


add(xnd(1), xnd(2))
```
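Without an explicit list of signatures, kernels have to be created from the types seen at call time. A minimal sketch of that idea follows, in the spirit of how Numba's `jit` compiles a new specialization the first time it sees a new combination of argument types; it is an assumption that signature-less registration would work this way. `LazyKernels` and `add_scalars` are hypothetical names, the "compilation" step is only a placeholder that reuses the Python function, and the output-argument convention is ignored for brevity.

```python
class LazyKernels:
    """Hypothetical sketch: create and cache one specialization per
    distinct tuple of argument types, on first use."""

    def __init__(self, func):
        self.func = func
        self.kernels = {}  # argument-type tuple -> specialized callable

    def _specialize(self, arg_types):
        # Placeholder for compilation: a real implementation would compile
        # `self.func` for `arg_types` here. This sketch reuses the function.
        return self.func

    def __call__(self, *args):
        key = tuple(type(arg) for arg in args)
        if key not in self.kernels:
            self.kernels[key] = self._specialize(key)
        return self.kernels[key](*args)


@LazyKernels
def add_scalars(a, b):
    return a + b


add_scalars(1, 2)      # first (int, int) call: creates a specialization
add_scalars(3, 4)      # reuses the cached (int, int) specialization
add_scalars(1.5, 2.0)  # first (float, float) call: creates a second one
print(len(add_scalars.kernels))  # 2
```

The trade-off against an explicit signature list is that the set of kernels is only known after the function has been called, so something like `len(matrix_multiply.kernels)` cannot be checked up front.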