An ONNX-backed array library that is compliant with the Array API standard.
Releases are available on PyPI and conda-forge.
```bash
# using pip
pip install ndonnx
# using conda (the package is published on conda-forge)
conda install -c conda-forge ndonnx
# using pixi
pixi add ndonnx
```
You can install the package in development mode using:
```bash
git clone https://github.com/quantco/ndonnx
cd ndonnx
# For Array API tests
git submodule update --init --recursive
pixi shell
pre-commit run -a
pip install --no-build-isolation --no-deps -e .
pytest tests -n auto
```
ndonnx is an ONNX-based Python array library. It has a couple of key features:

It implements the Array API standard. Standard-compliant code can be executed without changes across numerous backends, such as NumPy, JAX, and now ndonnx.
```python
import numpy as np
import ndonnx as ndx
import jax.numpy as jnp

def mean_drop_outliers(a, low=-5, high=5):
    xp = a.__array_namespace__()
    return xp.mean(a[(low < a) & (a < high)])

np_result = mean_drop_outliers(np.asarray([-10, 0.5, 1, 5]))
jax_result = mean_drop_outliers(jnp.asarray([-10, 0.5, 1, 5]))
onnx_result = mean_drop_outliers(ndx.asarray([-10, 0.5, 1, 5]))
assert np_result == onnx_result.to_numpy() == jax_result == 0.75
```
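The portability in the example above rests on the `__array_namespace__` protocol: the function never imports a backend and instead asks its input for the namespace of functions that operate on it. As a rough illustration, the sketch below defines a hypothetical toy 1-D array type, `ToyArray`, and a matching toy namespace, `toy_xp` (neither is part of ndonnx or of the standard), implementing only the handful of operations `mean_drop_outliers` needs. The function is repeated unchanged so the block runs on its own with just the Python standard library.

```python
from types import SimpleNamespace


class ToyArray:
    """Hypothetical minimal 1-D array (illustration only, not a real backend).

    Supports just what mean_drop_outliers uses: comparisons with scalars,
    elementwise `&`, boolean-mask indexing, and __array_namespace__.
    """

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, /, *, api_version=None):
        # Array API protocol: return the namespace whose functions
        # (here only `mean` and `asarray`) operate on this array type.
        return toy_xp

    def __lt__(self, other):
        # Elementwise `self < other` for a scalar `other`.
        return ToyArray(v < other for v in self.values)

    def __gt__(self, other):
        # Elementwise `self > other`; Python also calls this for `scalar < self`.
        return ToyArray(v > other for v in self.values)

    def __and__(self, other):
        # Elementwise logical AND of two boolean arrays.
        return ToyArray(x and y for x, y in zip(self.values, other.values))

    def __getitem__(self, mask):
        # Boolean-mask indexing: keep the elements where the mask is True.
        return ToyArray(v for v, keep in zip(self.values, mask.values) if keep)


def _toy_mean(x):
    return sum(x.values) / len(x.values)


# Hypothetical namespace object standing in for a real backend module.
toy_xp = SimpleNamespace(asarray=ToyArray, mean=_toy_mean)


def mean_drop_outliers(a, low=-5, high=5):
    xp = a.__array_namespace__()
    return xp.mean(a[(low < a) & (a < high)])


print(mean_drop_outliers(toy_xp.asarray([-10, 0.5, 1, 5])))  # 0.75
```

A real backend has to cover far more of the standard (dtypes, broadcasting, n-dimensional indexing); the point is only that `mean_drop_outliers` itself needs no change to run against a new array type.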
It supports ONNX export. This allows you to persist your logic into an ONNX computation graph.
```python
import ndonnx as ndx
import onnx

# Instantiate placeholder ndonnx array
x = ndx.array(shape=("N",), dtype=ndx.float32)
# mean_drop_outliers is the function defined in the previous example
y = mean_drop_outliers(x)

# Build and save ONNX model to disk
model = ndx.build({"x": x}, {"y": y})
onnx.save(model, "mean_drop_outliers.onnx")
```
You can then make predictions using a runtime of your choice.
```python
import onnxruntime as ort
import numpy as np

inference_session = ort.InferenceSession("mean_drop_outliers.onnx")
prediction, = inference_session.run(None, {
    "x": np.array([-10, 0.5, 1, 5], dtype=np.float32),
})
assert prediction == 0.75
```
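Conceptually, export works because the placeholder `x` carries no data: calling `mean_drop_outliers(x)` records the operations it would perform, and the recorded graph is later evaluated by a runtime on concrete inputs. The sketch below illustrates that record-then-run idea with a hypothetical toy tracer that supports only addition and multiplication. `Symbol`, `build`, and `run` are illustrative names, not ndonnx APIs; ndonnx itself emits ONNX operators, and its internals are not shown here.

```python
import operator

# Hypothetical op table for the toy tracer: op name -> Python implementation.
OPS = {"add": operator.add, "mul": operator.mul}


class Symbol:
    """Toy symbolic value: arithmetic records graph nodes instead of computing."""

    def __init__(self, op, inputs=(), name=None):
        self.op = op          # "input" or a key of OPS
        self.inputs = inputs  # upstream Symbols or Python constants
        self.name = name      # only used for graph inputs

    def __add__(self, other):
        return Symbol("add", (self, other))

    __radd__ = __add__  # addition is commutative, so operand order is harmless

    def __mul__(self, other):
        return Symbol("mul", (self, other))

    __rmul__ = __mul__  # likewise for multiplication


def build(output):
    """Collect recorded nodes in topological order (inputs before users)."""
    nodes, seen = [], set()

    def visit(node):
        if not isinstance(node, Symbol) or id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        nodes.append(node)

    visit(output)
    return nodes


def run(nodes, feeds):
    """Toy 'runtime': evaluate the node list on concrete input values."""
    values = {}

    def value_of(v):
        return values[id(v)] if isinstance(v, Symbol) else v

    for node in nodes:
        if node.op == "input":
            values[id(node)] = feeds[node.name]
        else:
            args = [value_of(v) for v in node.inputs]
            values[id(node)] = OPS[node.op](*args)
    return values[id(nodes[-1])]  # the output node is appended last


x = Symbol("input", name="x")  # placeholder: no data yet
y = 2 * x + 1                  # records a mul node and an add node
graph = build(y)
print([node.op for node in graph])  # ['input', 'mul', 'add']
print(run(graph, {"x": 3.0}))       # 7.0
```

Separating graph construction from execution is what allows a saved model to be evaluated later, and repeatedly, by a separate runtime such as onnxruntime.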
In the future, we will provide a stable API for an extensible data type system. This will allow users to define their own data types and operations on arrays with these data types.
Array API compatibility is tracked in `api-coverage-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome!
Summary (1119 total):
Run the tests with:
```bash
pixi run arrayapitests
```