data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

dtype from string helper function #71

Closed toddrjen closed 9 months ago

toddrjen commented 11 months ago

numpy.dtype supports creating a dtype from a string, such as numpy.dtype('float32'). torch.dtype does not, and as far as I can tell array-api-compat has no general way to do this other than calling getattr on the entire API namespace (which is not safe).

It would be helpful to have a compatibility wrapper that allows loading dtypes from strings. This is a particular problem with array-api-compat, since there is no way for a function to transparently use array-api-compat internally with dtypes other than by using strings.

rgommers commented 11 months ago

@toddrjen thanks for the issue. Can you please explain why you would want this? If you can write getattr(xp, 'float32'), you can also write xp.float32, right? I'm missing the actual use case, or example code showing why you'd want to use strings to begin with.

toddrjen commented 11 months ago

You can use getattr, but that exposes the entire namespace: if the wrong name is specified, it can easily pull in things that aren't dtypes, so it is generally unsafe.
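To make the concern concrete, here is a minimal sketch using NumPy as the namespace (`lookup_dtype_unchecked` is a hypothetical name, not part of any library). A bare getattr returns whatever attribute matches the string, dtype or not:

```python
import numpy as np


def lookup_dtype_unchecked(xp, name):
    # Returns whatever attribute of the namespace matches `name`;
    # nothing checks that the result is actually a dtype.
    return getattr(xp, name)


float_type = lookup_dtype_unchecked(np, "float32")  # the float32 type, as intended
not_a_dtype = lookup_dtype_unchecked(np, "sum")      # silently returns a function instead
print(float_type, callable(not_a_dtype))
```

A caller that passes "sum" by mistake gets no error at lookup time; the failure only shows up later, wherever the "dtype" is used.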

A simple example is a function that allows someone to specify an output dtype regardless of the array used, in a way that the function caller doesn't need to use array-api-compat directly themselves. Really the only way to do that is using a string.

toddrjen commented 11 months ago

Alternatives would be to put all the dtypes into their own module (that is the important part), or to put them all in a dict or SimpleNamespace.
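The dict/SimpleNamespace alternative could look roughly like the sketch below. It is an assumption of how such a container might be built, not an existing array-api-compat feature; `dtypes_namespace` is a hypothetical helper, and NumPy stands in for whatever namespace `xp` would be. The name list is the set of dtypes required by the Array API standard.

```python
from types import SimpleNamespace

import numpy as np

# Dtype names required by the Array API standard.
STANDARD_DTYPE_NAMES = (
    "bool",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float32", "float64",
    "complex64", "complex128",
)


def dtypes_namespace(xp):
    # Hypothetical helper: collect only the standard dtypes that `xp`
    # provides, so attribute lookups on the result never yield non-dtypes.
    found = {name: getattr(xp, name) for name in STANDARD_DTYPE_NAMES if hasattr(xp, name)}
    return SimpleNamespace(**found)


dtypes = dtypes_namespace(np)
print(getattr(dtypes, "float64"))  # string lookup restricted to dtypes
print(hasattr(dtypes, "sum"))      # non-dtype attributes are simply absent
```

String lookups then go through `getattr(dtypes, name)` (or `vars(dtypes)[name]` for dict-style access) instead of through the full namespace.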

rgommers commented 11 months ago

"A simple example is a function that allows someone to specify an output dtype regardless of the array used, in a way that the function caller doesn't need to use array-api-compat directly themselves. Really the only way to do that is using a string."

That doesn't sound quite right. Say you are writing a package that uses array-api-compat, and you are exposing a function like this to the users of your package:

from array_api_compat import array_namespace

def func(x, dtype=None):
    """
    x: array
        any compliant array type
    dtype: dtype, optional
        If given, the dtype of the output array
    """
    # Get the compliant namespace matching the input array
    xp = array_namespace(x)
    # Calculate something here using xp-compliant functions (squaring is a placeholder)
    result = x * x
    if dtype is not None:
        result = xp.astype(result, dtype)

    return result

Then if your user uses NumPy, they can call the function as func(x, dtype=np.float32), and if they use PyTorch they call it like func(x, dtype=torch.float32).

Using strings for the above seems like an anti-pattern. If I'm missing your point @toddrjen, then please illustrate it with a code snippet.

toddrjen commented 11 months ago

The idea is to be able to write processing routines that work identically for pytorch and numpy. Avoiding having to check the type of the array is why I am using array-api-compat, so having to specify that the dtype belongs to pytorch or numpy defeats the purpose.

So yes, your code snippet is correct. But the idea is for the user to not have to know or care whether a pytorch or numpy array is being passed.

vnmabus commented 11 months ago

But you could easily implement that functionality if desired using getattr plus some prior validation, at least for the mandatory dtypes of the array API standard (which form a finite, fixed set). If you want to allow possible extensions, such as int128, that is also doable using regexes. Am I missing anything here?
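A minimal sketch of that validate-then-getattr approach, under stated assumptions: `dtype_from_string` is a hypothetical helper, the extension regex is one possible choice (it accepts names like "float16" or "int128"), and NumPy is used as the namespace only to keep the example runnable. In a real package `xp` would typically come from `array_api_compat.array_namespace(x)`.

```python
import re

import numpy as np

# Dtypes required by the Array API standard: a finite, fixed set.
STANDARD_DTYPE_NAMES = frozenset({
    "bool",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float32", "float64",
    "complex64", "complex128",
})

# Assumed pattern for extension dtypes such as "float16" or "int128".
EXTENSION_PATTERN = re.compile(r"(u?int|float|complex)\d+")


def dtype_from_string(xp, name):
    """Return the dtype called `name` from namespace `xp`, refusing anything else."""
    # Validate the name first, so arbitrary attributes like "sum" are rejected.
    if name not in STANDARD_DTYPE_NAMES and not EXTENSION_PATTERN.fullmatch(name):
        raise ValueError(f"{name!r} is not a recognized dtype name")
    # Only then look it up; extensions may be missing from a given library.
    dtype = getattr(xp, name, None)
    if dtype is None:
        raise ValueError(f"{name!r} is not available in this namespace")
    return dtype


print(dtype_from_string(np, "float32"))
```

Names like "sum" fail the name check, while "int128" passes it but fails on NumPy, which has no such dtype.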

rgommers commented 11 months ago

"But the idea is for the user to not have to know or care whether a pytorch or numpy array is being passed."

Then I think I disagree with your goal here. There are two situations:

  1. The user of the API is writing generic code.
  2. The user is writing code specific to one application (or data analysis script, or similar) and one array library.

For (1), they do have a generic namespace at hand, so they can write xp.float32. For (2), they don't need the code to be generic, so using an actual dtype object like torch.float32 is better than "programming with strings".

I suspect that you are missing something important here: the end user is always using a single array library. The standard also doesn't have I/O functionality, so a user is starting with something like:

import numpy as np

x = np.load('my_data.npy')

rgommers commented 9 months ago

Since the discussion has run its course for now, I'll close this as wontfix (can always be revisited in case of more comments).

Thanks all!