data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

RFC: require that dtypes obey Python hashing rules #582

Open NeilGirdhar opened 1 year ago

NeilGirdhar commented 1 year ago

Python's documentation promises that: "The only required property is that objects which compare equal have the same hash value…" However, NumPy dtypes do not follow this requirement. As discussed in https://github.com/numpy/numpy/issues/7242, dtype objects, their types, and their names all compare equal despite hashing unequal. Could the Array API promise that this will no longer be the case?
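
For concreteness, here is the violation in a REPL session (observed with NumPy 1.x; equality holds, yet the hashes differ):

>>> import numpy as np
>>> d = np.dtype(np.float32)
>>> d == d.type, d == d.name   # the dtype compares equal to its scalar type and its name
(True, True)
>>> hash(d) == hash(d.type), hash(d) == hash(d.name)   # ...but hashes unequal
(False, False)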

rgommers commented 1 year ago

That seems fine to me to explicitly specify. float32 == 'float32' should clearly return False. In NumPy it's a bit messy:

>>> import numpy as np
>>> np.float32 == 'float32'
False
>>> np.dtype(np.float32) == 'float32'
True

Only the first example is relevant for the array API standard, so I think this will be fine to specify since NumPy already complies.

Here, however, there is a problem in NumPy:

>>> np.dtype(np.float64) == float
True

That can be considered a clear bug though, should be fixed in NumPy.

NeilGirdhar commented 1 year ago

Only the first example is relevant for the array API standard, so I think this will be fine to specify since NumPy already complies.

So you're saying that np.dtype(np.float32) == 'float32' will be true or false?

That can be considered a clear bug though, should be fixed in NumPy.

Agreed.

What about np.float32 == np.dtype(np.float32)?

This also violates Python's hashing invariant.

rgommers commented 1 year ago

What about np.float32 == np.dtype(np.float32)?

I agree that it's a bug technically. Not 100% sure that the NumPy team will want that changed, but I hope so (and a proposal for a major release is in the works, so that could go into it). For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.

So you're saying that np.dtype(np.float32) == 'float32' will be blocked or not?

That's more for the NumPy issue tracker, but if it were up to me then yes.

For this issue tracker, I'm +1 on adopting language in the standard like: "All objects in this standard must adhere to the following requirement (as required by Python itself): objects which compare equal have the same hash value".

NeilGirdhar commented 1 year ago

For this issue tracker, I'm +1 on adopting language in the standard like: "All objects in this standard must adhere to the following requirement (as required by Python itself): objects which compare equal have the same hash value".

That would be amazing. That's exactly what I was hoping for.

I agree that it's a bug technically. Not 100% sure that the NumPy team will want that changed, but I hope so (and a proposal for a major release is in the works, so that could go into it). For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.

Okay, thanks for explaining. If the above language were adopted, NumPy could implement that by making xp.float32 not simply an alias of np.dtype(np.float32), but rather a special dtype object that doesn't have the pernicious behavior.

rgommers commented 1 year ago

Let's give it a bit of time to see if anyone sees a reason not to add such a requirement. I can open a PR after the holidays.

NeilGirdhar commented 1 year ago

For the array API standard it's not an issue, because there is no dtype constructor/function in the main namespace.

Just noticed this comment. It is currently an issue in NumPy's implementation of the Array API:

import numpy.array_api as xp
xp.float32 == xp.float32.type  # True!

This is because xp.float32 points to an object np.dtype(np.float32). For this to be fixed, NumPy would just need a new dtype class for use in its Array API xp.
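
For instance, here is a minimal sketch of such a standalone dtype object (hypothetical, not proposed NumPy code; a frozen dataclass generates a matching __eq__/__hash__ pair for free):

from dataclasses import dataclass

@dataclass(frozen=True)
class DType:
    name: str

float32 = DType('float32')

assert float32 == DType('float32')
assert hash(float32) == hash(DType('float32'))
assert float32 != 'float32'  # no equality with strings or scalar types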

With the language you suggested above, NumPy would be forced to do this to become compliant 😄 .

So you're saying that np.dtype(np.float32) == 'float32' will be blocked or not?

That's more for the NumPy issue tracker, but if it were up to me then yes.

Same thing here, I think. NumPy will probably reject this for their own namespace (np), but if you adopt that language, they would have to fix it in the array API (xp).

Incidentally, I assume you want numpy.array_api.float32 to compare equal to jax.array_api.float32? Since there is no root project to provide a base implementation of dtypes, you may need to standardize how dtype.__hash__ and comparison work.

rgommers commented 1 year ago

xp.float32 == xp.float32.type # True!

There is no float32.type in the standard. That it shows up with numpy.array_api.float32 is because the dtype objects there are aliases to the regular numpy ones, rather than new objects. That was a shortcut I think, because adding new dtypes is a lot of work. So that's one place where currently numpy.array_api doesn't 100% meet its goal of being completely minimal.

Incidentally, I assume you want numpy.array_api.float32 to compare equal to jax.array_api.float32?

No, definitely not. No objects from two different libraries should ever compare equal, unless they're indeed the same object.

NeilGirdhar commented 1 year ago

So that's one place where currently numpy.array_api doesn't 100% meet its goal of being completely minimal.

Ok! Thanks for explaining.

No, definitely not. No objects from two different libraries should ever compare equal, unless they're indeed the same object.

So to do things like checking that two arrays have the same dtype, or creating a NumPy array that has the same type as a JAX array, we'll need mappings like:

m = {jax.array_api.float32: np.array_api.float32, ...}

And code like

np.array_api.ones_like(some_jax_array)  # works today, in either direction.

is impossible, yes? You need:

np.array_api.ones(some_jax_array.shape, dtype=m[some_jax_array.dtype])

rgommers commented 1 year ago

So to do things like checking that two arrays have the same dtype ...

Having to use library-specific constructs should not be necessary; if it is, we're missing an API, I'd say. More importantly: mixing arrays from different libraries like this is a bit of an anti-pattern. You can't do much with that; neither library has kernels for functions that use both array types, so you're probably relying on implicit conversion of one to the other.

So in this case, let me assume that x is a numpy array, y is a JAX array, and you want to use functions from x's namespace (numpy):

# First retrieve the namespace you want to work with
xp = x.__array_namespace__()
# Use DLPack or the buffer protocol to convert a CPU JAX array to a NumPy array
y = xp.asarray(y)
# Now we can compare dtypes:
if x.dtype == y.dtype == xp.float32:
    ...  # if the dtypes match, do stuff

# Or, similarly, via the standard's isdtype(dtype, kind) helper:
if xp.isdtype(x.dtype, xp.float32) and xp.isdtype(y.dtype, xp.float32):
    ...

is impossible, yes? You need:

yes indeed

I'm actually a little surprised JAX accepts numpy arrays. It seems to go against its philosophy; TensorFlow, PyTorch and CuPy will all raise. When you call jnp.xxx(a_numpy_array), JAX will also always make a copy, I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.

JAX also annotates its array inputs as array_like, but there it doesn't mean the same thing as in NumPy:

>>> jnp.sin([1, 2, 3])
...
TypeError: sin requires ndarray or scalar arguments, got <class 'list'> at position 0

All this stuff is bug-prone:

>>> jnp.sin(np.array([1, 2, 3]))
Array([0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

>>> jnp.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))  # bug in user code here, because JAX silently discards mask
Array([0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

>>> np.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))
masked_array(data=[--, 0.9092974268256816, --],
             mask=[ True, False,  True],
       fill_value=1e+20)

NeilGirdhar commented 1 year ago

More importantly: mixing arrays from different libraries like this is a bit of an anti-pattern. You can't do much with that; neither library has kernels for functions that use both array types, so you're probably relying on implicit conversion of one to the other.

Okay, makes sense. I haven't been very conscious about this because (as you pointed out) JAX implicitly converts. I will be more careful.

y = xp.asarray(y)

I think this is where I'm confused. Somehow numpy has to know what its equivalent dtypes are for JAX's dtypes even though they don't compare equal? Or will it produce a numpy array with a JAX dtype? Since this seems to work:

In [12]: x = jnp.ones(10, jnp.bfloat16)

In [14]: np.asarray(x)
Out[14]: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=bfloat16)

When you call jnp.xxx(a_numpy_array), JAX will also make a copy always I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.

Very interesting. I wonder what the JAX team would say.

rgommers commented 1 year ago

I think this is where I'm confused. Somehow numpy has to know what its equivalent dtypes are for JAX's dtypes even though they don't compare equal? Or will it produce a numpy array with a JAX dtype?

NumPy knows the dtype, as does JAX. This conversion uses the Python buffer protocol or DLPack, both of which are protocols explicitly meant for exchanging data in a reliable way (that includes dtype, shape, endianness, etc.). So the asarray call will produce a numpy array with a numpy dtype, and to do so numpy does not need to know anything specifically about JAX.
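
For example, both exchange routes in a short sketch (assuming a CPU JAX array, and NumPy >= 1.22 for np.from_dlpack):

>>> import numpy as np
>>> import jax.numpy as jnp
>>> x = jnp.arange(3.0)   # a CPU JAX array, dtype float32
>>> np.asarray(x)         # via __array__ / the buffer protocol
array([0., 1., 2.], dtype=float32)
>>> np.from_dlpack(x)     # via DLPack, zero-copy on CPU where possible
array([0., 1., 2.], dtype=float32)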

When you call jnp.xxx(a_numpy_array), JAX will also make a copy always I believe, since it doesn't want to share memory. An explicit copy made by the user is clearer and more portable.

Very interesting. I wonder what the JAX team would say.

Let's try to find out :) This section of the JAX docs only explains why JAX doesn't accept list/tuple/etc., but I cannot find an explanation of why it does accept numpy arrays and scalars. @shoyer or @jakevdp, would you be able to comment on why JAX implements a limited form of "array-like"?

Also, in addition to the bug with masked arrays above, here is another bug:

>>> jnp.sin(np.float64(1.5))  # silent precision loss here, downcasting to float32
Array(0.997495, dtype=float32)
>>> jax.__version__
'0.4.1'

NeilGirdhar commented 1 year ago

So the asarray call will produce a numpy array with a numpy dtype, and to do so numpy does not need to know anything specifically about JAX.

In that case, shouldn't there be a way to convert dtypes using either the buffer protocol or DLPack? Something more efficient than:

def x_to_y_dtype(some_xp_dtype: DType, yp: ArrayNamespace) -> DType:
    xp = some_xp_dtype.__array_namespace__()  # doesn't exist
    x = xp.ones((), dtype=some_xp_dtype)
    return yp.asarray(x).dtype

Should dtypes have a __array_namespace__ attribute? Currently, they don't. So, the above function can't be written unless you know xp.

rgommers commented 1 year ago

No, those protocols are specifically for exchanging data (strided arrays/buffers). A dtype without data isn't very meaningful. You could exchange a size-1 array if needed, or a 'float32' string representation, or whatever works.
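
For illustration, that size-1-array trick as a sketch (names are hypothetical; xp and yp are array-API namespaces):

def convert_dtype(dtype, xp, yp):
    # Carry the dtype across libraries by exchanging a 0-D array;
    # yp.asarray does the actual transfer (DLPack / buffer protocol).
    probe = xp.ones((), dtype=dtype)
    return yp.asarray(probe).dtype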

NeilGirdhar commented 1 year ago

No, those protocols are specifically for exchanging data (strided arrays/buffers).

I understand, but in order to exchange data, they have to be able to convert dtypes. So, that dtype conversion is happening somehow, and I was just wondering if that conversion can be accessed by the user.

rgommers commented 1 year ago

It's not user-accessible; it's all under the hood.

Specifically for JAX you have a shortcut, because it reuses NumPy dtypes directly:

>>> type(jnp.float32)
<class 'jax._src.numpy.lax_numpy._ScalarMeta'>
>>> type(jnp.float32.dtype)
<class 'numpy.dtype[float32]'>
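
So, leaning on that implementation detail (a shortcut, not something the standard guarantees):

>>> np.ones(3, dtype=jnp.float32.dtype)
array([1., 1., 1.], dtype=float32)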

NeilGirdhar commented 1 year ago

(Thanks for all the patient explanations!)

jakevdp commented 1 year ago

@shoyer or @jakevdp, would you be able to comment on why JAX implements a limited form of "array-like"?

JAX avoids implicit conversion of Python sequences, because it can hide severe performance issues. When something like x = [i for i in range(10000)] is passed to the XLA compiler, it is passed as a list of 10000 XLA scalars. We found this to be a common mistake people made, and decided to disallow it. This is discussed at Non-Array Inputs: Numpy vs. JAX.

On the other hand, np.arange(10000) is a single XLA array, and doesn't have this problem. On CPU, the transfer can even be done in most cases in a zero-copy fashion, although on accelerators there will be a device transfer cost.

Also, in addition to the bug with masked arrays above, here is another bug:

>>> jnp.sin(np.float64(1.5))  # silent precision loss here, downcasting to float32
Array(0.997495, dtype=float32)
>>> jax.__version__
'0.4.1'

This is working as intended: JAX only allows 64-bit values when explicitly enabled; see Double (64-bit) Precision. This was an early design decision that the team recognizes as non-ideal, but it has proven difficult to change because so many users depend on the bit truncation behavior and enjoy the accelerator-friendly type safety it confers.
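
For reference, enabling it looks like this (a sketch; the flag must be set before any arrays are created):

>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> import jax.numpy as jnp, numpy as np
>>> jnp.sin(np.float64(1.5))  # float64 is preserved once x64 is enabled
Array(0.99749499, dtype=float64)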

Specifically for JAX you have a shortcut, because it reuses NumPy dtypes directly:

>>> type(jnp.float32)
<class 'jax._src.numpy.lax_numpy._ScalarMeta'>
>>> type(jnp.float32.dtype)
<class 'numpy.dtype[float32]'>

The reason JAX defines this is that it made the early design choice not to distinguish between scalars and zero-dimensional arrays. np.float32, despite its common use (probably stemming from NumPy's anything-goes approach to dtype equality/identity that is the original reason for this issue), is not a dtype; rather, it is a scalar float32 type. When JAX added the jax.numpy convenience wrapper around its core functionality, it needed dtype-specific scalar constructors similar to NumPy's np.float32, np.int32, etc. that would output appropriate zero-dimensional arrays. It does not make sense for jnp.float32 to be its own type because, unlike NumPy, there are no dedicated scalar types in JAX.

We could have defined simple functions named float32, int32, etc. but because np.float32 is so commonly used as a stand-in for np.dtype('float32'), we needed the scalar constructor functions to be something that np.dtype would treat as a dtype, and so the _ScalarMeta classes were born.
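
A rough sketch of that pattern (illustrative only, not JAX's actual code; it relies on np.dtype accepting any object with a dtype attribute):

import numpy as np

class _ScalarMeta(type):
    def __call__(cls, value):
        # Calling float32(x) builds a 0-D array rather than a scalar
        # (JAX returns its own Array type; np.asarray is a stand-in here).
        return np.asarray(value, dtype=cls.dtype)

class float32(metaclass=_ScalarMeta):
    dtype = np.dtype('float32')

print(float32(1.5))       # 0-D array, dtype float32
print(np.dtype(float32))  # float32 -- np.dtype reads the dtype attribute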

>>> jnp.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))  # bug in user code here, because JAX silently discards mask
Array([0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

To my knowledge, this bug has never come up (probably because masked arrays are so rarely used in practice). I'll raise it in the JAX repo.

rgommers commented 1 year ago

Thanks for the context @jakevdp!

This is working as intended: JAX only allows 64-bit values when explicitly enabled; see Double (64-bit) Precision. This was an early design decision that has proven difficult to change

I knew that 64-bit precision must be explicitly enabled, but this is still surely a bug? The expected behavior is an exception, asking the user to explicitly downcast if the precision loss is fine, or to enable 64-bit precision. Or at the very least emit a warning. Silent downcasting is terrible - it may be okay for deep learning, but it typically isn't for general-purpose numerical/scientific computing.

NeilGirdhar commented 1 year ago

On the other hand, np.arange(10000) is a single XLA array, and doesn't have this problem. On CPU, the transfer can even be done in most cases in a zero-copy fashion, although on accelerators there will be a device transfer cost.

That might be unfortunate when one converts a CPU program to a GPU one? It might be nice to be able to enable a flag that makes this into a runtime error. That way I can remove all of my unintentional JAX/NumPy array mixing.

jakevdp commented 1 year ago

Silent downcasting is terrible - it may be okay for deep learning, but it typically isn't for general-purpose numerical/scientific computing.

I think you hit on the key point here: there are different communities with different requirements, and JAX attempts, maybe clumsily, to serve them all. If you are doing deep learning and care about performance over potential precision loss, you can set JAX_ENABLE_X64=0. If you are using JAX for general purposes and don't want this, you can set JAX_ENABLE_X64=1. The fact that the former is the default was an early decision based on initial uses of the package; we've actively discussed changing it, but it would be a big change and there are many pros and cons that must be weighed.

It's a difficult problem to solve well in a single package: it's worth noting that NumPy's answer to requests to serve the needs of deep learning is essentially no, which is a defensible choice given the package's early design decisions.

leofang commented 1 year ago

What about np.float32 == np.dtype(np.float32)?

This has been one of the few NumPy things that I dislike (and that would be moot for the Array API). In NumPy, np.float32 is a Python type

>>> type(np.float32)
<class 'type'>

whereas np.dtype(np.float32) is a dtype instance

>>> type(np.dtype(np.float32))
<class 'numpy.dtype[float32]'>

The former is needed, IIUC, only because of the need to construct NumPy scalars. Once NumPy removes this concept (how about NumPy 2.0, @seberg? 🙂), we can (and should) make them equivalent!

seberg commented 1 year ago

I do not really want to touch removing scalars from NumPy; maybe someone more confident about it can push for such a thing...

Maybe to be clear: to change NumPy here I see no other way than the following (I think this is what Ralf said):

If you remove scalars, then np.float32(0) would have to raise an error, which helps, but is also noisy?

I don't see another way, so you can put it into np.array_api or np.ndarray.__array_namespace__, but np.float32 is borked and I don't see how to fix it except by doing the above, and probably doing it very slowly.

leofang commented 1 year ago

I'd argue that if there's any design that could bring us closer to full compliance in the main namespace with the standard, we should consider it, and removing scalars in favor of 0-D arrays is one of them. It's been a source of confusion with no obvious gain except for keeping legacy code working. It's been made clear that no accelerator library would support it. Also, removing scalars would keep the type promotion lattice cleaner.

So,

Is arr.dtype == np.dtype(...) good enough

Yes.

then np.float32(0) would have to raise an error, which helps, but is also noisy?

Not at all noisy 🙂

and probably doing it very slowly

All I care about is 1. eventual compliance, and 2. reducing both user confusion and developer (you) workload 🙂 If this is something that could take 1 full developer year to do, so be it.

seberg commented 1 year ago

You can change NumPy relatively easily. The problem is dealing with whether pandas and others need involved changes. So the issue about scalars (and to some degree also this in general) is that it is very much holistic. I can zoom in on NumPy and give you a branch where scalars may still be in the code base but should never be created... But I am not sure I am equipped with understanding what would happen to pandas or ... if we do it.

(I am also admittedly the one person who hates everything about NumPy scalars, but isn't sure that scalars themselves are all that bad.)

rgommers commented 1 year ago

Scalars themselves aren't that bad, if only they weren't created by operations like np.sum(x) and x[()]. If those started returning 0-D arrays, that probably wouldn't even break that much in downstream libraries; the problem is the corner cases in end user code where they're used in places where Python scalars are expected.
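
For example (NumPy 1.x; the exact integer width is platform-dependent):

>>> import numpy as np
>>> type(np.sum(np.arange(3)))  # a scalar, not a 0-D array
<class 'numpy.int64'>
>>> type(np.array(1.0)[()])     # same for x[()]
<class 'numpy.float64'>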

I have a sense that it's doable in principle, but that it's one step too far for NumPy 2.0.

seberg commented 1 year ago

Yes, I would be willing to experiment with the "getting scalars more right" part. But also yes: even that needs at least testing to have confidence that it would be 2.0-scoped (i.e., few enough users actually notice, and if they do, mostly in harmless ways).

rgommers commented 1 year ago

Yes, I would be willing to experiment with the "getting scalars more right" part. But also yes: even that needs at least testing to have confidence that it would be 2.0-scoped (i.e., few enough users actually notice, and if they do, mostly in harmless ways).

I'd be interested in helping out with an effort like this. But I don't think I can be one of the two "champions" (see here) for this one; I already signed up for enough other stuff for NumPy 2.0.

asmeurer commented 1 year ago

Going back to the original discussion, another annoying thing NumPy does is

>>> np.dtype('float64') == None
True

which has tripped us up a few times in the test suite.

asmeurer commented 1 year ago

Would it be possible in NumPy to make np.float64 just be np.dtype('float64') by implementing __call__ and __instancecheck__ on it (the actual type of a scalar float64 would become a hidden _float64 class)? That wouldn't remove scalars but it would make it so that there is only one dtype object representing any given dtype.
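
A quick sketch of the mechanics (hypothetical; isinstance defers to __instancecheck__ defined on the type of its second argument):

class _float64(float):
    """Hidden scalar class standing in for today's np.float64 instances."""

class _DTypeFloat64:
    def __call__(self, value=0.0):
        # np.float64(x) keeps working as a scalar constructor
        return _float64(value)

    def __instancecheck__(self, obj):
        # isinstance(x, np.float64) keeps working too
        return isinstance(obj, _float64)

    def __eq__(self, other):
        return isinstance(other, _DTypeFloat64)

    def __hash__(self):
        # equal objects hash equal, so Python's invariant holds
        return hash('float64')

float64 = _DTypeFloat64()  # the single object representing this dtype

assert isinstance(float64(1.5), float64)  # via __instancecheck__
assert float64 == _DTypeFloat64() and hash(float64) == hash(_DTypeFloat64())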

NeilGirdhar commented 1 year ago

That wouldn't remove scalars but it would make it so that there is only one dtype object representing any given dtype.

I love this idea. This would be a step towards what Leo wanted above: "any design that could bring us closer to full compliance in the main namespace with the standard". I think if we don't do what you're suggesting, it will be a source of confusion that np.float32 is a type, but np.array_api.float32 is a dtype.

by implementing __call__ and __instancecheck__

(And __subclasscheck__.)