arrayfire / arrayfire-rust

Rust wrapper for ArrayFire
BSD 3-Clause "New" or "Revised" License

Cannot use imax_all in typed function [BUG] #300

Closed by HammerBu 3 years ago

HammerBu commented 3 years ago

imax_all works fine when called directly in the main function:

use arrayfire as af;

fn main() {
    let dims = af::Dim4::new(&[3, 1, 1, 1]);
    let a = af::randu::<f32>(dims);
    af::af_print!("Create 3 random floats on the GPU", a);

    let v_max  = af::imax_all(&a).0;
    println!("Max value: {}", v_max);
}

output:

Create a 5-by-3 matrix of random floats on the GPU
[3 1 1 1]
    0.6010
    0.0278
    0.9806
Max value: 0.9805506

However, it fails to compile when called from a generic (typed) function:

use arrayfire as af;

fn GetMax<T: af::HasAfEnum>(a: &af::Array<T>) {
    let v_max  = af::imax_all(&a).0;
    println!("Max value: {}", v_max);
}

fn main() {
    let dims = af::Dim4::new(&[3, 1, 1, 1]);
    let a = af::randu::<f32>(dims);
    af::af_print!("Create 3 random floats on the GPU", a);

    GetMax(&a);
}

error:

the trait bound <<T as arrayfire::HasAfEnum>::InType as arrayfire::HasAfEnum>::BaseType: arrayfire::Fromf64 is not satisfied

What should I do to use imax_all in a function that accepts arrays of different element types (u8, f16, f32)?

9prady9 commented 3 years ago

@HammerBu Sorry about the delay; I have been on leave and then sick for the past few weeks.

When writing a generic function, you need to add trait bounds that cover every bound required by the arrayfire functions it calls. Otherwise the compiler cannot prove that those requirements hold for an arbitrary T. This is true of generic functions in Rust in general, not just of this crate.
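The rule above can be sketched without arrayfire at all. The block below is a minimal illustration using only the standard library; `Element`, `FromF64`, `lib_max`, and `get_max` are hypothetical stand-ins invented for this example and are not ArrayFire's real definitions. It shows a "library" function whose where-clause constrains an associated type, and a user-written generic wrapper that compiles only because it restates that constraint.

```rust
use std::fmt::Debug;

// Hypothetical stand-ins for a library's traits (NOT ArrayFire's definitions).
trait Element: Copy + Into<f64> {
    // Note: no bound on Base here, so callers cannot assume anything about it.
    type Base;
}

trait FromF64 {
    fn from_f64(v: f64) -> Self;
}

impl Element for u8 {
    type Base = f32;
}
impl Element for f32 {
    type Base = f32;
}
impl FromF64 for f32 {
    fn from_f64(v: f64) -> Self {
        v as f32
    }
}

// "Library" function: its where-clause constrains an associated type of T.
fn lib_max<T>(data: &[T]) -> T::Base
where
    T: Element,
    T::Base: FromF64,
{
    // Widen every element to f64, take the maximum, convert to the base type.
    let m = data
        .iter()
        .map(|&x| Into::<f64>::into(x))
        .fold(f64::NEG_INFINITY, f64::max);
    T::Base::from_f64(m)
}

// User-defined generic wrapper. Deleting the `FromF64` bound below produces
// a "trait bound ... is not satisfied" error at the call to `lib_max`.
fn get_max<T>(data: &[T]) -> T::Base
where
    T: Element,
    T::Base: FromF64 + Debug, // FromF64 restated from lib_max; Debug for {:?}
{
    let v = lib_max(data);
    println!("Max value: {:?}", v);
    v
}

fn main() {
    get_max(&[0.6010_f32, 0.0278, 0.9806]);
    get_max(&[3_u8, 200, 17]);
}
```

Declaring the bound on the associated type inside the trait itself (for example `type Base: FromF64;`) would make it implied for every `T: Element`, which is one design alternative; when a library does not do that, each generic caller has to repeat the bound in its own where-clause.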

In your case, the imax_all documentation lists the required bounds, and the compiler error you posted names the missing trait bound directly. For your reference, here is a version of GetMax with the trait bounds added.

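use arrayfire::{imax_all, Array, Fromf64, HasAfEnum};

fn GetMax<T>(a: &Array<T>)
where
    T: HasAfEnum,
    // Bound required by imax_all, as named in the compiler error.
    <<T as HasAfEnum>::InType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
{
    // `a` is already a reference, so pass it through as-is.
    let v_max = imax_all(a).0;
    println!("Max value: {:?}", v_max);
}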

Closing as the issue has been addressed. If you have any follow-up questions, feel free to post them on the Slack channel too.