PyO3 / rust-numpy

PyO3-based Rust bindings of the NumPy C-API
BSD 2-Clause "Simplified" License

Update to 0.16 breaks dtype template #289

Closed: bernardelli closed this issue 2 years ago

bernardelli commented 2 years ago

The commit https://github.com/PyO3/rust-numpy/commit/2e2b5763997708f534996876bf9a1e4932f09e60 breaks this piece of code.

I implemented a pyo3 function that uses a template to generate code for different numpy data types.

I used a match expression based on the value of a numpy::DataType variable:

#[pyfunction]
pub fn generic_test(input: &PyAny) -> PyResult<PyObject> {
    let (dtype, shape) = pick_array_info(input)?;

    if shape.len() > 1 {
        return Err(PyValueError::new_err(
            "`input` must have only one dimension",
        ));
    }

    macro_rules! implementation {
        ($dtype:ident) => {{
            let array: &PyArray1<$dtype> = convert_to_array(input);

            println!("Got {}", stringify!($dtype));

            let result = my_reduce(array);

            let gil = Python::acquire_gil();
            let py = gil.python();

            let result_to_py = PyArray1::from_slice(py, &[result]);

            Ok(result_to_py.to_object(py))
        }};
    }
    return match dtype {
        DataType::Bool => implementation!(u8),
        DataType::Int8 => implementation!(i8),
        [ . . . ]
        DataType::Float64 => implementation!(f64)
    };
}

Since numpy::DataType was dropped in v0.16, I wonder what the best way to achieve this behavior is.

adamreichold commented 2 years ago

I think PyArrayDescr::num would best serve your use case, as you appear to use only enumerated types (we do not support more at the moment but hope to do so in the future). The NPY_TYPES enumeration should contain the necessary constants.
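To illustrate the idea of matching on a type number, here is a minimal std-only sketch. It does not use the numpy crate: the constants are a hand-copied subset that mirrors the values of NumPy's C enumeration NPY_TYPES, and `dispatch` and `implementation` are hypothetical names. In real code the number would presumably come from PyArrayDescr::num and the constants from the crate's NPY_TYPES.

```rust
use std::mem::size_of;

// Assumption: a hand-copied subset mirroring NumPy's C enum NPY_TYPES.
// Real code would use the constants exposed by the numpy crate instead.
const NPY_BOOL: i32 = 0;
const NPY_BYTE: i32 = 1; // i8
const NPY_UBYTE: i32 = 2; // u8
const NPY_FLOAT: i32 = 11; // f32
const NPY_DOUBLE: i32 = 12; // f64

// Hypothetical generic body; here it only reports the element size in bytes.
fn implementation<T>() -> usize {
    size_of::<T>()
}

// Pick the monomorphized `implementation` that matches the type number.
// Unsupported numbers yield `None` instead of panicking.
fn dispatch(type_num: i32) -> Option<usize> {
    match type_num {
        NPY_BOOL => Some(implementation::<bool>()),
        NPY_BYTE => Some(implementation::<i8>()),
        NPY_UBYTE => Some(implementation::<u8>()),
        NPY_FLOAT => Some(implementation::<f32>()),
        NPY_DOUBLE => Some(implementation::<f64>()),
        _ => None,
    }
}
```

Calling `dispatch(NPY_DOUBLE)` selects the `f64` instantiation. Returning an `Option` leaves it to the caller to turn an unsupported dtype into a Python exception.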

adamreichold commented 2 years ago

(As an aside, I don't think implementation really needs to be a macro but could be a generic function with a T: Element bound instead?)
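A rough sketch of what the generic-function variant could look like, kept std-only so it compiles without pyo3 or numpy. The body of `my_reduce` is not shown in this thread, so the sum below is purely an assumption, and the std trait bounds stand in for the `T: Element` bound (plus whatever arithmetic the real reduction needs).

```rust
use std::ops::Add;

// Hypothetical stand-in for the thread's `my_reduce`: sums all elements.
// A slice replaces `&PyArray1<T>` so the sketch needs no Python runtime.
fn my_reduce<T>(values: &[T]) -> T
where
    T: Copy + Default + Add<Output = T>,
{
    // Start from the type's default (zero for numeric types) and accumulate.
    values.iter().copied().fold(T::default(), |acc, v| acc + v)
}
```

One function body then covers every element type that satisfies the bounds, for example `my_reduce(&[1i8, 2, 3])` or `my_reduce(&[1.5f64, 2.5])`, and the compiler monomorphizes it per type much like the macro expansion did.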

adamreichold commented 2 years ago

As another aside, while it will not be as efficient as unsafely picking the array info, I think you should be able to achieve the same result safely using something like

#[pyfn(m)]
fn generic(a: &PyAny) {
    fn implementation<T: Element>(a: &PyArray1<T>) {
        // ...
    }

    macro_rules! dispatch {
        ($a:ident, $($ty:ty),+) => {
            $(
            let a: PyResult<&PyArray1<$ty>> = $a.extract();
            if let Ok(a) = a {
                return implementation(a);
            }
            )+

            panic!("array has unsupported element type");
        };
    }

    dispatch!(a, u8, i32, f64);
}

or a bit more data-oriented using

#[derive(FromPyObject)]
enum GenericArray1<'py> {
    U8(&'py PyArray1<u8>),
    I32(&'py PyArray1<i32>),
    F64(&'py PyArray1<f64>),
}

#[pyfn(m)]
fn generic(a: GenericArray1) {
    fn implementation<T: Element>(a: &PyArray1<T>) {
        // ...
    }

    match a {
        GenericArray1::U8(a) => implementation(a),
        GenericArray1::I32(a) => implementation(a),
        GenericArray1::F64(a) => implementation(a),
    }
}

bernardelli commented 2 years ago

Thank you for the quick reply! I'll be testing this solution this week and I'll let you know if it works. I guess the issue is closed by now :smile: