pyccel / pyccel-cuda

Cuda extension to pyccel
MIT License

Clearer use of ndarray_type_registry #17

Open EmilyBourne opened 1 year ago

EmilyBourne commented 1 year ago

Currently we have an ndarray_type_registry which is used for both N-D arrays and Cuda arrays. In order to use the same object for Cuda arrays, we have code such as the following:

dtype = self.find_in_ndarray_type_registry(self._print(rhs.dtype), rhs.precision)
dtype = dtype[3:]
code_init += 'array_fill_{0}(({1}){2}, {3});\n'.format(dtype, declare_dtype, self._print(rhs.fill_value), self._print(lhs))

This code is hard to follow: unless we are familiar with find_in_ndarray_type_registry, it is not obvious that the [3:] strips the 3-character nd_ prefix. It is also not future-proof: if the prefix ever changes length, statements like this will be very hard to find. Furthermore, if ndarray_type_registry is also used for Cuda types, its docstring should be updated to make clear that it affects more than just ndarrays.
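To make the fragility concrete, here is a small sketch of the current slicing pattern. The type names are illustrative assumptions: 'nd_double' is assumed to resemble an entry of ndarray_type_registry, and 'ndarr_double' stands for a hypothetical longer prefix. The helper name cuda_dtype_from is also hypothetical.

```python
def cuda_dtype_from(ndarray_type_name):
    # Hypothetical helper mirroring the current pattern:
    # blindly drop the first 3 characters, assumed to be 'nd_'
    return ndarray_type_name[3:]

print(cuda_dtype_from('nd_double'))     # double
print(cuda_dtype_from('ndarr_double'))  # rr_double  (silently wrong, no error)
```

Nothing fails when the prefix changes; the generated C name is simply wrong, which is why such slices are hard to track down.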

To clear this up, I recommend renaming the function to find_in_array_type_registry and having it return the type name without the nd_ prefix. Cuda code can then do:

dtype = self.find_in_array_type_registry(self._print(rhs.dtype), rhs.precision)
code_init += 'array_fill_{0}(({1}){2}, {3});\n'.format(dtype, declare_dtype, self._print(rhs.fill_value), self._print(lhs))

Code for ndarrays would instead prepend the prefix explicitly:

dtype = 'nd_'+self.find_in_array_type_registry(self._print(rhs.dtype), rhs.precision)

Alternatively, to avoid code duplication, we could keep self.find_in_ndarray_type_registry as a thin wrapper around the new function:

def find_in_ndarray_type_registry(self, dtype, precision):
    # Look up the shared (prefix-free) name, then add the ndarray prefix
    type_name = self.find_in_array_type_registry(dtype, precision)
    return f'nd_{type_name}'
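
The proposal can be sketched end to end as follows. This is a minimal stand-in, not pyccel's printer: the class name SketchPrinter, the dict array_type_registry, its (dtype, precision) keys and its entries are all hypothetical, chosen only to show how Cuda code and ndarray code would share one prefix-free lookup.

```python
class SketchPrinter:
    # Hypothetical shared registry: (dtype, precision) -> base type name,
    # stored WITHOUT the 'nd_' prefix so Cuda code can use it directly
    array_type_registry = {
        ('float', 8): 'double',
        ('float', 4): 'float',
        ('int', 8): 'int64',
        ('int', 4): 'int32',
    }

    def find_in_array_type_registry(self, dtype, precision):
        """Return the prefix-free type name shared by ndarray and Cuda code."""
        try:
            return self.array_type_registry[(dtype, precision)]
        except KeyError:
            raise KeyError(f'No array type for dtype={dtype!r}, precision={precision}') from None

    def find_in_ndarray_type_registry(self, dtype, precision):
        """Return the ndarray-specific name: 'nd_' + the shared name."""
        type_name = self.find_in_array_type_registry(dtype, precision)
        return f'nd_{type_name}'

printer = SketchPrinter()

# Cuda-style code: use the shared name directly, no slicing needed
cuda_dtype = printer.find_in_array_type_registry('float', 8)
print('array_fill_{0}(({1}){2}, {3});'.format(cuda_dtype, 'double', '0.0', 'a'))
# array_fill_double((double)0.0, a);

# ndarray-style code: go through the wrapper to get the prefixed name
print(printer.find_in_ndarray_type_registry('float', 8))
# nd_double
```

Keeping the prefix in exactly one place (the wrapper) means a future change to it touches a single line instead of every [3:] slice.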

EmilyBourne commented 3 weeks ago

Related to old branch