data-apis / array-api-strict

Strict implementation of the Python array API (previously numpy.array_api)
http://data-apis.org/array-api-strict/

Allow changing the default dtypes #38

Open asmeurer opened 4 months ago

asmeurer commented 4 months ago

See https://data-apis.org/array-api/latest/API_specification/data_types.html#default-data-types and https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.default_dtypes.html#array_api.info.default_dtypes.

We should add flags to set_array_api_strict_flags to configure these away from the NumPy defaults.

One concern here is that in some instances of moving from float64 to float32, we might have to just downcast the result from NumPy. The computation would then still happen in float64, producing a result that could differ from a library that actually does everything in float32. This should be worked around wherever possible by downcasting the input before computing rather than the output.
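To illustrate the concern, here is a minimal sketch using plain NumPy. The function `f` and the input values are made up for illustration and are not part of array-api-strict; they are chosen so that float32 rounding is visible.

```python
import numpy as np


def f(a, b):
    # Hypothetical computation that is sensitive to precision:
    # mathematically this is just `b`.
    return (a + b) - a


a64 = np.array([1.0], dtype=np.float64)
b64 = np.array([1e-8], dtype=np.float64)

# Option 1: compute in float64, then downcast the output.
# The small term survives because float64 can represent 1 + 1e-8.
out_downcast = f(a64, b64).astype(np.float32)

# Option 2: downcast the inputs, then compute entirely in float32.
# In float32, 1 + 1e-8 rounds to exactly 1, so the result is exactly 0.
in_downcast = f(a64.astype(np.float32), b64.astype(np.float32))

print(out_downcast.dtype, out_downcast)
print(in_downcast.dtype, in_downcast)
```

Both results have dtype float32, but only the second matches what a library that computes natively in float32 would return.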

asmeurer commented 4 months ago

We could also add behavior to emulate missing dtypes. This would require rewriting the existing code a little bit, so I'm only really inclined to implement this if people ask for it. It would help map to libraries like PyTorch, but at the same time, people will just test against those libraries, so it isn't strictly necessary for array-api-strict to be the provider of this behavior.
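As a rough idea of what emulating missing dtypes could look like, here is a sketch on top of plain NumPy. The names `DISABLED_DTYPES` and `asarray_checked` are hypothetical and do not exist in array-api-strict, and the set of disabled dtypes is an arbitrary example.

```python
import numpy as np

# Hypothetical configuration: dtypes the emulated library does not provide.
DISABLED_DTYPES = {np.dtype("uint16"), np.dtype("uint32"), np.dtype("uint64")}


def asarray_checked(obj, dtype=None):
    """Create an array, raising if its dtype is marked as unavailable."""
    arr = np.asarray(obj, dtype=dtype)
    if arr.dtype in DISABLED_DTYPES:
        raise TypeError(f"dtype {arr.dtype} is disabled in this configuration")
    return arr


ok = asarray_checked([1, 2, 3], dtype=np.float32)  # allowed dtype

try:
    asarray_checked([1, 2, 3], dtype=np.uint16)  # disabled dtype
except TypeError as exc:
    print(exc)
```

A real implementation would also need to check the result dtype of every function, not just array creation, which is where the rewriting of existing code would come in.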