honno closed this 1 year ago.
Ah, I see. I think this is easier to fix in the elif chain at https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/main/torch_np/_dtypes.py#L275.
Two options: either do a local import of numpy (though we frown upon local imports) and add a check along the lines of
In [23]: issubclass(np.empty(3).dtype.type, np.generic)
Out[23]: True
or go the duck-typed route: if the argument has a name attribute, feed arg.name to sctype_from_string:
In [25]: dt = np.empty(3).dtype
In [26]: dt.name
Out[26]: 'float64'
We deliberately made our name aliases follow numpy, so this should work (or maybe a rarer alias will need to be added to one of the dicts between https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/main/torch_np/_dtypes.py#L164 and def sctype_from_string).
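The duck-typed route can be sketched roughly as follows. This is a simplified illustration, not the project's code: _NAME_ALIASES, normalize_dtype and FakeNumpyDtype are hypothetical, and this sketch's sctype_from_string only returns a canonical name string. The point is that any object with a string .name attribute (a numpy.dtype, for instance) gets resolved through the same alias lookup as a plain string, and numpy is never imported.

```python
# Hypothetical, simplified stand-ins for the helpers in torch_np/_dtypes.py.
# The alias table is illustrative only, not the project's actual dicts.
_NAME_ALIASES = {
    "bool": "bool",
    "int32": "int32",
    "int64": "int64",
    "float32": "float32",
    "float64": "float64",
    "double": "float64",
}


def sctype_from_string(s):
    """Map a dtype name or alias to a canonical dtype name."""
    try:
        return _NAME_ALIASES[s]
    except KeyError:
        raise TypeError(f"data type {s!r} not understood") from None


def normalize_dtype(arg):
    """Resolve a dtype-like argument without importing numpy."""
    if isinstance(arg, str):
        return sctype_from_string(arg)
    # Duck typing: anything exposing a string .name (e.g. a numpy.dtype,
    # whose .name is 'float64' for np.empty(3).dtype) is resolved by name.
    name = getattr(arg, "name", None)
    if isinstance(name, str):
        return sctype_from_string(name)
    raise TypeError(f"cannot interpret {arg!r} as a data type")


class FakeNumpyDtype:
    """Stand-in for numpy.dtype: the only thing duck typing needs is .name."""

    def __init__(self, name):
        self.name = name


print(normalize_dtype(FakeNumpyDtype("float64")))  # float64
print(normalize_dtype("double"))  # float64
```

Because the lookup goes through the same alias table as plain strings, supporting a rarer NumPy name only requires adding it to the table.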
See my latest commit, which uses str(arg) as the fallback in DType.__init__(). The only place we'd really need to import numpy is if we want to support namespaced dtypes (e.g. np.int64, np.float64)... or we could just regex search the str repr to match known dtype strings, haha. Supporting namespaced dtypes would be nice, as that's what you usually get from numpy.ndarray.dtype.
I've left those namespaced dtype tests xfailed, so this PR is actually good to merge if you're happy with the str(arg) fallback.
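The "regex search the str repr" idea might look something like the sketch below. It relies on two facts: str(np.dtype("float64")) is "float64", and str(np.int64) is "<class 'numpy.int64'>". The _KNOWN_NAMES set, name_from_str_repr and FakeScalarType are hypothetical names for illustration only. FakeScalarType just mimics that str() output so the sketch runs without numpy.

```python
import re

# Hypothetical set of canonical names; the real project keeps its own tables.
_KNOWN_NAMES = {"bool", "int8", "int16", "int32", "int64", "uint8",
                "float16", "float32", "float64", "complex64", "complex128"}

# str(np.int64) is "<class 'numpy.int64'>"; capture the part after "numpy.".
_NAMESPACED_REPR = re.compile(r"<class 'numpy\.(\w+)'>")


def name_from_str_repr(arg):
    """Fallback: match str(arg) against known dtype names, no numpy import."""
    s = str(arg)
    if s in _KNOWN_NAMES:  # e.g. str(np.dtype("float64")) == "float64"
        return s
    m = _NAMESPACED_REPR.fullmatch(s)
    if m and m.group(1) in _KNOWN_NAMES:  # e.g. np.int64, np.float64
        return m.group(1)
    raise TypeError(f"data type {arg!r} not understood")


class FakeScalarType:
    """Stand-in whose str() looks like that of a numpy scalar type."""

    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


print(name_from_str_repr(FakeScalarType("<class 'numpy.int64'>")))  # int64
print(name_from_str_repr("float32"))  # float32
```

Using fullmatch against the whole repr, rather than searching for substrings, avoids accidentally picking up something like "int8" from inside "uint8".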
Test failures are real, I'm afraid. The problem is not your code, though; it's a crutch left over from some messy prototyping. Consider
In [12]: tnp.dtype(float)
Out[12]: dtype("float64")
This (surprise!) goes through sctype_from_string, which accepts Python types via https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/main/torch_np/_dtypes.py#L244.
This is ugly indeed. Let's add an extra elif for this in DType.__init__ instead.
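Here is a rough sketch of that extra elif, so that sctype_from_string only ever has to deal with strings. The DType class, _PYTHON_SCALAR_TYPES and _KNOWN_NAMES below are hypothetical simplifications, not the project's actual code. The int → "int64" entry assumes NumPy's default integer on 64-bit Linux/macOS.

```python
# Hypothetical sketch of a DType constructor; not the project's real class.
# int -> "int64" assumes NumPy's default integer on 64-bit Linux/macOS.
_PYTHON_SCALAR_TYPES = {
    bool: "bool",
    int: "int64",
    float: "float64",
    complex: "complex128",
}

_KNOWN_NAMES = {"bool", "int32", "int64", "float32", "float64", "complex128"}


def sctype_from_string(s):
    """After the refactor this only needs to understand strings."""
    if s not in _KNOWN_NAMES:
        raise TypeError(f"data type {s!r} not understood")
    return s


class DType:
    def __init__(self, arg):
        if isinstance(arg, DType):
            name = arg.name
        elif isinstance(arg, type) and arg in _PYTHON_SCALAR_TYPES:
            # The extra elif: Python builtins are handled here explicitly
            # instead of being routed through sctype_from_string.
            name = _PYTHON_SCALAR_TYPES[arg]
        elif isinstance(arg, str):
            name = sctype_from_string(arg)
        else:
            raise TypeError(f"cannot interpret {arg!r} as a data type")
        self.name = name

    def __repr__(self):
        return f'dtype("{self.name}")'


print(DType(float))  # dtype("float64")
```

The dict lookup matches on the exact type, so bool maps to "bool" even though bool is a subclass of int.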
We don't want to import numpy in this project if we can avoid it, as NumPy is an optional dependency in PyTorch.
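One possible way to keep the issubclass(..., np.generic) style of check without ever importing numpy is to look numpy up in sys.modules. This technique is not proposed in the thread. If the caller never imported numpy, then in practice no numpy dtype can be passed in. The function name is_numpy_dtype_like is hypothetical.

```python
import sys


def is_numpy_dtype_like(arg):
    """True if arg is a numpy.dtype or a numpy scalar type such as np.float64.

    Never imports numpy: if the caller has not imported it, no numpy object
    can (in practice) be passed in, so the answer is simply False.
    """
    np = sys.modules.get("numpy")
    if np is None:
        return False
    if isinstance(arg, np.dtype):
        return True
    # Same idea as issubclass(dt.type, np.generic), applied directly to
    # scalar types like np.float64.
    return isinstance(arg, type) and issubclass(arg, np.generic)


print(is_numpy_dtype_like(float))  # False
```

Once the caller has done import numpy as np, is_numpy_dtype_like(np.empty(3).dtype) should return True. The project itself still never pays the import cost.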
As discussed, I don't think this is high prio at the moment, and given the limited budget we have, let's just punt on this for now.
Leaving it open in case we decide to fix this later on.
I've removed my (attempted) fix, so this PR now only contains the tests I introduced, with anything failing marked as xfail. IMO we can at least merge these tests?
This builds on https://github.com/Quansight-Labs/numpy_pytorch_interop/pull/116, as that contains a few useful tidbits.
@ev-br see test_dtype.py for the failing tests I wrote for converting NumPy dtypes (and thus arrays) to the equivalent interop dtypes. Currently it's a mixed bag as to what gets converted correctly.
I think the solution is pretty simple: we make sure to try using the str() of dtypes as a fallback (special-casing "bool_" as "bool") before raising an error.
Not a priority, just to say I think fixing this is very useful for "comparison testing" between libraries that consume NumPy-proper arrays (including NumPy) and this interop library. See test_put as an example of how powerful comparison testing can be with relatively little work.
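Below is a minimal sketch of the proposed str() fallback. It uses hypothetical _KNOWN_NAMES and _SPECIAL_CASES tables and a stand-in Foreign class, none of which exist in torch_np. The regular lookup runs first. Then str(arg) is tried, with "bool_" mapped to "bool", and only after that does it raise. Since str(np.dtype("float64")) is "float64", a fallback like this should cover most plain NumPy dtypes.

```python
# Hypothetical canonical names and special cases, for illustration only.
_KNOWN_NAMES = {"bool", "int8", "int16", "int32", "int64", "uint8",
                "float16", "float32", "float64", "complex64", "complex128"}
_SPECIAL_CASES = {"bool_": "bool"}


def dtype_name(arg):
    """Resolve arg to a canonical dtype name, trying str(arg) as a last resort."""
    # Regular path: strings that already name a known dtype.
    if isinstance(arg, str) and arg in _KNOWN_NAMES:
        return arg
    # Fallback before giving up: use str(arg), special-casing "bool_".
    s = str(arg)
    s = _SPECIAL_CASES.get(s, s)
    if s in _KNOWN_NAMES:
        return s
    raise TypeError(f"data type {arg!r} not understood")


class Foreign:
    """Stand-in for a dtype object from another library; only str() matters."""

    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


print(dtype_name(Foreign("float32")))  # float32
print(dtype_name(Foreign("bool_")))  # bool
```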