PyO3 / rust-numpy

PyO3-based Rust bindings of the NumPy C-API
BSD 2-Clause "Simplified" License

Add support for safely creating universal functions in Rust #400

Open adamreichold opened 9 months ago

adamreichold commented 9 months ago

This does not support polymorphic universal functions, but the numbers of inputs and outputs are arbitrary.

Closes #399

adamreichold commented 9 months ago

@mhostetter Would you be interested and able to test this branch? Would the functionality available here suffice to handle your use case?

EDIT: The test examples should give a rough idea of how this works.

mhostetter commented 9 months ago

> @mhostetter Would you be interested and able to test this branch? Would the functionality available here suffice to handle your use case?

I can try, but I fear my knowledge is too limited to be effective. I've only done an example or two with PyO3 and rust-numpy. Do I need to pip install this branch? I'm also unsure of how to write the Rust code.

To summarize my goals in words... Write a NumPy ufunc in Rust. This function takes x and y which are single elements of two arrays. The function may also take other configuration parameters, e.g. modulo. An example Rust ufunc of addition in a prime finite field would return (x + y) % modulo. I then expose this Rust-written ufunc to Python via PyO3. Then I can invoke this ufunc on arbitrarily-sized NumPy arrays in Python (using normal NumPy broadcasting). I can also change the modulo in Python at runtime, e.g. rust_ufunc(x, y, 7) or rust_ufunc(x, y, 13).

adamreichold commented 9 months ago

> I can try, but I fear my knowledge is too limited to be effective. I've only done an example or two with PyO3 and rust-numpy. Do I need to pip install this branch? I'm also unsure of how to write the Rust code.

I think to test this branch, you would only need to change your Cargo.toml to replace the version dependency on numpy by a Git one, e.g. replace

numpy = "0.20"

by

numpy = { git = "https://github.com/PyO3/rust-numpy.git", branch = "ufunc" }

> To summarize my goals in words... Write a NumPy ufunc in Rust. This function takes x and y which are single elements of two arrays. The function may also take other configuration parameters, e.g. modulo. An example Rust ufunc of addition in a prime finite field would return (x + y) % modulo. I then expose this Rust-written ufunc to Python via PyO3. Then I can invoke this ufunc on arbitrarily-sized NumPy arrays in Python (using normal NumPy broadcasting). I can also change the modulo in Python at runtime, e.g. rust_ufunc(x, y, 7) or rust_ufunc(x, y, 13).

To my understanding, universal functions always take 1-dimensional vectors as inputs (so they have a chance of vectorizing the inner-most loop) and take their outputs explicitly.
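To make the shape of such an inner loop concrete, here is a minimal sketch in plain Rust using only the standard library. It does not use the rust-numpy API from this branch: `add_mod_kernel` is a hypothetical name, and ordinary slices stand in for the 1-dimensional views that NumPy would pass in. The point is only that the kernel receives whole 1-dimensional inputs and writes into an output buffer it is given, rather than returning a value per element.

```rust
/// Hypothetical inner loop (not part of rust-numpy): reads two
/// 1-dimensional inputs and writes into a caller-supplied output.
fn add_mod_kernel(x: &[u64], y: &[u64], z: &mut [u64], m: u64) {
    assert!(m != 0, "modulus must be non-zero");
    assert!(x.len() == z.len() && y.len() == z.len(), "length mismatch");

    for ((xi, yi), zi) in x.iter().zip(y.iter()).zip(z.iter_mut()) {
        // Widen to u128 so that xi + yi cannot overflow before reduction.
        *zi = ((*xi as u128 + *yi as u128) % m as u128) as u64;
    }
}
```

Because the loop is a straight pass over contiguous data with the output passed in explicitly, the compiler has a reasonable chance of vectorizing it.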

So if your modulus is basically fixed, you could inject it by capturing it in the closure that defines your ufunc, e.g.

let m = ...;

let add_mod_m = move |[x, y]: [ArrayView1<'_, u64>; 2], [z]: [ArrayViewMut1<'_, u64>; 1]| {
  azip!((x in x, y in y, z in z) *z = (*x + *y) % m);
};

let add_mod_m = numpy::ufunc::from_func(py, CString::new("add_mod_m").unwrap(), numpy::ufunc::Identity::Zero, add_mod_m);

module.add("add_mod_m", add_mod_m).unwrap();

If you want to vary m, then you need to add it as a third input parameter that NumPy will then broadcast to a 1-dimensional array (which is trivial using a zero stride).
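As a rough illustration of the zero-stride idea, the sketch below uses only the standard library and is not the rust-numpy API. `Strided` and `add_mod_strided` are hypothetical names, and strides are counted in elements here for simplicity, whereas NumPy itself works with byte strides. A scalar modulus is represented as a one-element buffer with stride 0, so every iteration reads the same value.

```rust
/// Hypothetical 1-dimensional operand: a backing buffer plus a stride
/// measured in elements (an assumption for this sketch).
struct Strided<'a> {
    data: &'a [u64],
    stride: usize,
}

impl Strided<'_> {
    /// Element seen at logical index `i`.
    fn at(&self, i: usize) -> u64 {
        self.data[i * self.stride]
    }
}

/// Hypothetical inner loop with three inputs (x, y, m) and one output.
fn add_mod_strided(x: &Strided, y: &Strided, m: &Strided, z: &mut [u64]) {
    for (i, zi) in z.iter_mut().enumerate() {
        // With m.stride == 0 this always reads m.data[0]: the scalar
        // modulus is broadcast without being copied.
        let modulus = m.at(i) as u128;
        assert!(modulus != 0, "modulus must be non-zero");
        // Widen to u128 so the sum cannot overflow before reduction.
        *zi = ((x.at(i) as u128 + y.at(i) as u128) % modulus) as u64;
    }
}
```

Under these assumptions, calling the loop once with a zero-stride modulus of 7 and once with 13 corresponds to the `rust_ufunc(x, y, 7)` and `rust_ufunc(x, y, 13)` calls described earlier in the thread, and a full-length modulus array with stride 1 would work through the same code path.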