wjakob / nanobind

nanobind: tiny and efficient C++/Python bindings
BSD 3-Clause "New" or "Revised" License

[BUG]: Complex type does not work in variant #479

Closed: awni closed this issue 3 months ago

awni commented 3 months ago

Problem description

If std::complex<float> is one of the alternatives in a std::variant, the variant caster incorrectly matches non-complex arguments as complex (see the example below).

Reproducible example code

#include <iostream>
#include <nanobind/nanobind.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/variant.h>

namespace nb = nanobind;

using MyType = std::variant<
    nb::float_,
    std::complex<float>>;

void print_type(MyType v) {
  if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
    std::cout << "GOT COMPLEX " << std::endl;
  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
    std::cout << "GOT FLOAT" << std::endl;
  }
}

NB_MODULE(test, m) {
    m.def("print_type", &print_type);
}

And in Python:

import test

# Correctly routes to float
test.print_type(1.0)

# Incorrectly routes to complex
test.print_type(list())
awni commented 3 months ago

Thanks for the fix, @yosh-matsuda!

awni commented 3 months ago

@yosh-matsuda and @wjakob, there still seems to be an issue here. In the example above, if I pass in a NumPy array, it matches complex:

import numpy as np
test.print_type(np.array(1.0))

Would you mind reopening this, or shall I file a separate bug report?

yosh-matsuda commented 3 months ago

@awni I think this is expected behavior.

Roughly speaking, the casting in your case seems to be attempted by the following:

yosh-matsuda commented 3 months ago

How about using float instead of nb::float_?

using MyType = std::variant<
    float,
    std::complex<float>>;
awni commented 3 months ago

I see, thanks. What I actually want is a third alternative in the variant, nb::ndarray<>. I suppose that if I check for that first, it will catch this case. Let me try.

awni commented 3 months ago

PS: what works for me is to put the ndarray before std::complex<float> in the variant. Thanks for your help!