data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Result of `numpy.sum` with `float32` input is `float64` #152

Open: mdhaber opened this issue 2 months ago

mdhaber commented 2 months ago

IIUC, the standard states that when `dtype` is unspecified, `sum` must preserve the input dtype for floating-point inputs:

> `dtype`: data type of the returned array. If `None`, the returned array must have the same data type as `x`, unless `x` has an integer data type

However,

```python
from array_api_compat import numpy as np
x = np.arange(10, dtype=np.float32)
np.sum(x).dtype  # dtype('float64')
```
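For comparison, here is a minimal sketch of the behavior the quoted rule asks for, written against plain NumPy rather than the compat layer. The helper name `sum_preserving_float` is hypothetical and not part of array-api-compat. It only covers the floating-point case (real and complex) and leaves integer inputs to NumPy's own defaults, which is a simplification of the standard's integer rules.

```python
import numpy as np

def sum_preserving_float(x, axis=None, dtype=None, keepdims=False):
    """Hypothetical helper, not part of array-api-compat.

    When dtype is None and x is floating-point (real or complex),
    pass x.dtype explicitly so the result keeps the input dtype.
    Integer inputs fall through to NumPy's default accumulator rules.
    """
    if dtype is None and np.issubdtype(x.dtype, np.inexact):
        dtype = x.dtype
    return np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)

x = np.arange(10, dtype=np.float32)
print(sum_preserving_float(x).dtype)  # float32
```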
asmeurer commented 2 months ago

This is part of #127: this behavior was changed in the 2023.12 version of the standard, and I don't see a need to keep the previously standardized behavior. Unfortunately, I haven't yet started updating array-api-compat for 2023.12, but I plan to start that work now.

mdhaber commented 2 months ago

Interesting. I didn't notice that the old standard required the output dtype to be the default floating-point type. Glad that was changed!
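To make the change discussed above concrete, the sketch below contrasts the two rules for a real floating-point input with `dtype=None`: the older rule promotes to the default floating-point type, while the 2023.12 rule preserves the input dtype. It assumes the previous version is 2022.12 and that the default floating-point dtype is `float64`, as it is for NumPy. The function `sum_result_dtype` is hypothetical, and integer and complex inputs are deliberately not handled.

```python
import numpy as np

def sum_result_dtype(x_dtype, version):
    """Hypothetical helper: expected dtype of sum(x) with dtype=None.

    Only real floating-point inputs are covered; version strings are
    assumptions ("2022.12" stands for the pre-2023.12 rule).
    """
    x_dtype = np.dtype(x_dtype)
    if not np.issubdtype(x_dtype, np.floating):
        raise ValueError("sketch only covers real floating-point inputs")
    if version == "2022.12":
        # Old rule: promote to the default real floating-point dtype.
        return np.dtype(np.float64)
    if version == "2023.12":
        # New rule: preserve the input dtype.
        return x_dtype
    raise ValueError(f"unknown version: {version}")

for v in ("2022.12", "2023.12"):
    print(v, sum_result_dtype(np.float32, v))
```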