bernardelli closed this issue 2 years ago.
I think `PyArrayDescr::num` would best serve your use case, as you appear to use only enumerated types (we do not support more at the moment but are hoping to do so in the future). (The `NPY_TYPES` enumeration should contain the necessary constants.)
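As a rough illustration of the `PyArrayDescr::num` suggestion, here is a std-only sketch that assumes the type number has already been read from the array's descriptor. The constants are local stand-ins whose values mirror NumPy's `NPY_TYPES` enumeration; real code should use the constants rust-numpy exports instead. `implementation` and `dispatch_on_num` are hypothetical names, and the kernel only reports the element size so that the selected instance is easy to observe.

```rust
use std::mem::size_of;
use std::os::raw::c_int;

// Local stand-ins for NumPy's NPY_TYPES values (assumption: these mirror
// the C enum; real code should use the constants exported by rust-numpy).
const NPY_UBYTE: c_int = 2;
const NPY_INT: c_int = 5;
const NPY_DOUBLE: c_int = 12;

// Hypothetical generic kernel: it only reports the element size, so the
// effect of choosing a monomorphized instance is easy to observe.
fn implementation<T>() -> usize {
    size_of::<T>()
}

// Map a runtime type number to a compile-time element type.
// Returns None for type numbers this sketch does not handle.
fn dispatch_on_num(type_num: c_int) -> Option<usize> {
    match type_num {
        NPY_UBYTE => Some(implementation::<u8>()),
        NPY_INT => Some(implementation::<i32>()),
        NPY_DOUBLE => Some(implementation::<f64>()),
        _ => None,
    }
}
```

The mapping from C type numbers to Rust types is not always the same across platforms (for example, `NPY_LONG` is 32 bits on Windows but 64 bits on most Linux systems), so each arm should pair a type number with the Rust type of matching size.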
(As an aside, I don't think `implementation` really needs to be a macro; could it be a generic function with a `T: Element` bound instead?)
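To make the aside concrete: a single generic function with a trait bound lets the compiler generate one copy per element type, which is what the macro does by hand. This std-only sketch uses a local `Element` trait as a stand-in for `numpy::Element` (an assumption; the real trait has different requirements), and `implementation` is a hypothetical summing kernel.

```rust
use std::ops::Add;

// Stand-in for numpy::Element (assumption for this sketch, which only
// needs copyable numbers with a zero value and addition).
trait Element: Copy + Default + Add<Output = Self> {}
impl Element for u8 {}
impl Element for i32 {}
impl Element for f64 {}

// One generic function replaces a macro that stamps out a copy per dtype;
// the compiler monomorphizes it for each T it is called with.
fn implementation<T: Element>(data: &[T]) -> T {
    data.iter().copied().fold(T::default(), |acc, x| acc + x)
}
```

Calling `implementation(&[1u8, 2, 3])` and `implementation(&[1.5f64, 2.5])` instantiates the `u8` and `f64` versions respectively, with no macro involved.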
As another aside, while it will not be as efficient as unsafely picking the array info, I think you should be able to achieve the same result safely using something like:
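```rust
use numpy::{Element, PyArray1};
use pyo3::prelude::*;

// Inside the #[pymodule] function, where `m` is the module being built.
#[pyfn(m)]
fn generic(a: &PyAny) {
    // One generic implementation, monomorphized per element type.
    fn implementation<T: Element>(a: &PyArray1<T>) {
        // ...
    }

    // Try to extract `a` as a 1-D array of each listed element type in turn
    // and call the matching instance of `implementation`.
    macro_rules! dispatch {
        ($a:ident, $($ty:ty),+) => {
            $(
                let a: PyResult<&PyArray1<$ty>> = $a.extract();
                if let Ok(a) = a {
                    return implementation(a);
                }
            )+
            panic!("array has unsupported element type");
        };
    }

    dispatch!(a, u8, i32, f64);
}
```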
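The extraction loop can be modeled without PyO3 by letting `std::any::Any::downcast_ref` stand in for `extract`: each listed type is tried in order, and the first successful downcast selects the instance of `implementation` to call. This is an analogy rather than the PyO3 mechanism; `implementation` here is a hypothetical kernel that counts elements, and the sketch returns an error instead of panicking.

```rust
use std::any::Any;

// Hypothetical generic kernel: returns the number of elements seen.
fn implementation<T>(a: &[T]) -> usize {
    a.len()
}

// Try each candidate element type in turn, mirroring the
// `$a.extract::<&PyArray1<$ty>>()` loop; `downcast_ref` plays the role
// of extraction in this std-only model.
macro_rules! dispatch {
    ($a:ident, $($ty:ty),+) => {
        $(
            if let Some(v) = $a.downcast_ref::<Vec<$ty>>() {
                return Ok(implementation(v.as_slice()));
            }
        )+
        return Err("array has unsupported element type".to_string());
    };
}

fn generic(a: &dyn Any) -> Result<usize, String> {
    dispatch!(a, u8, i32, f64);
}
```

In the PyO3 version, the same design choice would mean giving `generic` a `PyResult<()>` return type, returning `Ok(implementation(a))` from each arm, and ending the macro with an `Err(PyTypeError::new_err(...))` (assuming `pyo3::exceptions::PyTypeError`) instead of `panic!`, so Python callers see a `TypeError`.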
or, in a somewhat more data-oriented style, using:
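```rust
use numpy::{Element, PyArray1};
use pyo3::prelude::*;

// Extraction tries the variants in declaration order and keeps the first
// one whose element type matches the array's dtype.
#[derive(FromPyObject)]
enum GenericArray1<'py> {
    U8(&'py PyArray1<u8>),
    I32(&'py PyArray1<i32>),
    F64(&'py PyArray1<f64>),
}

// Inside the #[pymodule] function, where `m` is the module being built.
#[pyfn(m)]
fn generic(a: GenericArray1) {
    // One generic implementation, monomorphized per element type.
    fn implementation<T: Element>(a: &PyArray1<T>) {
        // ...
    }

    // Exhaustive match: each variant calls the matching instance.
    match a {
        GenericArray1::U8(a) => implementation(a),
        GenericArray1::I32(a) => implementation(a),
        GenericArray1::F64(a) => implementation(a),
    }
}
```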
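A std-only model of the enum approach, in which owned `Vec`s stand in for `&PyArray1<T>`; `process` and the summing `implementation` are hypothetical names. The point of this design is that the `match` is checked for exhaustiveness at compile time, whereas in the PyO3 version `#[derive(FromPyObject)]` tries the variants at runtime and reports an extraction error if none of them matches.

```rust
// Std-only model of GenericArray1: one variant per supported element
// type (owned Vecs stand in for &PyArray1<T>).
enum GenericArray1 {
    U8(Vec<u8>),
    I32(Vec<i32>),
    F64(Vec<f64>),
}

// Hypothetical generic kernel: sum of the elements, converted to f64.
fn implementation<T: Copy + Into<f64>>(a: &[T]) -> f64 {
    a.iter().map(|&x| -> f64 { x.into() }).sum()
}

fn process(a: &GenericArray1) -> f64 {
    // The match is exhaustive: adding a variant without an arm is a
    // compile error, so no runtime "unsupported type" branch is needed.
    match a {
        GenericArray1::U8(v) => implementation(v),
        GenericArray1::I32(v) => implementation(v),
        GenericArray1::F64(v) => implementation(v),
    }
}
```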
Thank you for the quick reply! I'll be testing this solution this week and I'll let you know if it works. I guess the issue is closed by now :smile:
The commit https://github.com/PyO3/rust-numpy/commit/2e2b5763997708f534996876bf9a1e4932f09e60 breaks this piece of code.

I implemented a `pyo3` function that uses a template to generate code for different NumPy data types. I used a match expression based on the value of a `numpy::DataType` variable:

Since `numpy::DataType` has been dropped in v0.16, I wonder what the best way to achieve this behavior is.