jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License
693 stars 39 forks

where method does not work with PyTorch #38

Closed eserie closed 3 years ago

eserie commented 3 years ago

When one of the two value arguments passed to where() is a tensor with dtype float64 and the other is a plain Python number (int, float), the where() method raises an error:

RuntimeError: expected scalar type float but found double
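A minimal sketch of one possible workaround, written in plain PyTorch rather than eagerpy. It assumes the error comes from the Python number being turned into a default float32 tensor before it reaches torch.where, which the report does not confirm. The helper name where_matching_dtype is hypothetical and not part of eagerpy or PyTorch. It builds the scalar operand with the dtype of the tensor operand, so both value arguments agree:

```python
import torch


def where_matching_dtype(condition, x, y):
    """Hypothetical helper: call torch.where after giving a Python scalar
    operand the dtype (and shape) of the tensor operand.

    Assumes at least one of x and y is already a torch.Tensor.
    """
    if not isinstance(x, torch.Tensor):
        # x is a plain number: build a tensor like y, filled with x.
        x = torch.full_like(y, x)
    elif not isinstance(y, torch.Tensor):
        # y is a plain number: build a tensor like x, filled with y.
        y = torch.full_like(x, y)
    return torch.where(condition, x, y)


# float64 tensor combined with a Python float, the case described above.
values = torch.tensor([-1.0, 2.0, -3.0], dtype=torch.float64)
result = where_matching_dtype(values > 0, values, 0.0)
print(result.dtype, result.tolist())
```

Because the scalar takes the tensor's dtype, a float fill value combined with an integer tensor would be cast to that integer dtype; whether that is acceptable depends on the caller.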