inducer / pytato

Lazily evaluated arrays in Python

[Bug] Scalar dtype incompatibility with numpy #414

Open kaushikcfd opened 1 year ago

kaushikcfd commented 1 year ago
>>> import numpy as np
>>> import pytato as pt
>>> a_np = np.random.rand(10, 4).astype(np.float32)
>>> np.minimum(a_np, 255).dtype
dtype('float32')
>>> pt.minimum(pt.make_data_wrapper(a_np), 255).dtype
dtype('float64')
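The NumPy half of this comparison can be reproduced without pytato. The second computation below is only a guess at where the float64 comes from and has not been checked against pytato's source. Suppose the scalar 255 is first given NumPy's default integer dtype and then promoted as an ordinary dtype. In that case, float32 combined with that integer dtype yields float64.

```python
import numpy as np

a_np = np.random.rand(10, 4).astype(np.float32)

# NumPy does not let the Python int 255 widen the float32 array,
# so the result stays float32.
numpy_dtype = np.minimum(a_np, 255).dtype

# Hypothesis only: give the scalar a concrete dtype first (NumPy's
# default integer, int64 or int32 depending on platform/version) and
# then promote dtype-to-dtype. Either integer width promotes with
# float32 to float64, which matches the pytato output above.
scalar_dtype = np.asarray(255).dtype
strict_dtype = np.result_type(a_np.dtype, scalar_dtype)

print(numpy_dtype, strict_dtype)
```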
kaushikcfd commented 1 year ago

Should we use np.min_scalar_type for the scalar operands in get_common_dtype_of_ary_or_scalars?
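Here is a minimal sketch of that suggestion. It uses a hypothetical helper, common_dtype_min_scalar, which is not pytato's actual get_common_dtype_of_ary_or_scalars. Anything with a dtype attribute contributes that dtype, while each Python scalar contributes np.min_scalar_type(value). The dtypes are then combined with np.result_type.

```python
import numpy as np


def common_dtype_min_scalar(*args):
    """Hypothetical helper: promote array dtypes together with the
    smallest dtype able to hold each Python scalar operand."""
    dtypes = []
    for arg in args:
        if isinstance(arg, (bool, int, float, complex)):
            # e.g. 255 -> uint8, 255.0 -> float16, 1e300 -> float64
            dtypes.append(np.min_scalar_type(arg))
        else:
            # Arrays (NumPy or lazy) are assumed to expose .dtype.
            dtypes.append(np.dtype(arg.dtype))
    # Pure dtype-to-dtype promotion; no value-based casting happens here.
    return np.result_type(*dtypes)


a = np.zeros((10, 4), dtype=np.float32)
print(common_dtype_min_scalar(a, 255))
```

This approach has a trade-off. When every operand is a scalar, the rule can settle on very narrow types; for example, two Python floats such as 1.5 and 2.5 yield float16. A real implementation would likely need a fallback for that case.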

inducer commented 1 year ago

JAX is another important precedent here. Its developers have put considerable thought into its type promotion semantics. It might be reasonable to follow their lead, which I think would, in this case, match the semantics you're after.
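To make the contrast concrete, here is a simplified sketch of the "weak" Python scalar idea that JAX's promotion rules rely on. A Python scalar only signals a kind (bool, integer, floating, complex) and does not widen an array of equal or higher kind. This is a NumPy-only approximation written for illustration. The helper name weak_scalar_result_dtype, the kind ranking, and the default dtypes are all assumptions, and the sketch does not reproduce JAX's full promotion lattice.

```python
import numpy as np

# Assumed ranking of dtype kinds: bool < integer < floating < complex.
_KIND_RANK = {"b": 0, "u": 1, "i": 1, "f": 2, "c": 3}
# Assumed dtype used when a scalar's kind outranks every array operand.
_DEFAULT_FOR_RANK = {
    0: np.dtype(np.bool_),
    1: np.dtype(np.int64),
    2: np.dtype(np.float64),
    3: np.dtype(np.complex128),
}


def _python_scalar_rank(x):
    # bool must be checked before int, since bool subclasses int.
    if isinstance(x, bool):
        return 0
    if isinstance(x, int):
        return 1
    if isinstance(x, float):
        return 2
    if isinstance(x, complex):
        return 3
    raise TypeError(f"unsupported scalar type: {type(x).__name__}")


def weak_scalar_result_dtype(*args):
    """Hypothetical helper: arrays set the dtype; Python scalars only
    raise the kind if they outrank every array operand."""
    if not args:
        raise ValueError("need at least one operand")
    array_dtypes = [np.dtype(a.dtype) for a in args if hasattr(a, "dtype")]
    scalar_ranks = [_python_scalar_rank(a) for a in args if not hasattr(a, "dtype")]
    scalar_rank = max(scalar_ranks, default=-1)

    if not array_dtypes:
        # Only scalars: use the default dtype of the highest kind present.
        return _DEFAULT_FOR_RANK[scalar_rank]

    result = np.result_type(*array_dtypes)
    if scalar_rank <= _KIND_RANK[result.kind]:
        # Weak scalar: the array dtype wins (float32 stays float32).
        return result
    # Scalar of a higher kind: promote with that kind's default dtype.
    return np.result_type(result, _DEFAULT_FOR_RANK[scalar_rank])


print(weak_scalar_result_dtype(np.zeros((10, 4), dtype=np.float32), 255))
```

Unlike the min_scalar_type approach, this rule ignores the scalar's value, so 255 and 10**6 behave the same against a float32 array. The flip side is that this sketch does not detect an integer scalar that does not fit an integer array's dtype (say 300 against uint8). Handling that would need a separate range check.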