Closed thomasjpfan closed 1 year ago
This looks like a bug in `numpy.array_api`. `reshape` in the spec only accepts a tuple (in all versions of the spec): https://data-apis.org/array-api/draft/API_specification/generated/array_api.reshape.html#array_api.reshape
It looks like `broadcast_to` and the `axis` argument of `permute_dims` also have this issue. We might consider changing this in the spec. Most other places that accept a shape or an axis accept either an int or a tuple of ints.
Let's see what the decision here is https://github.com/data-apis/array-api/issues/608. We should probably update numpy.array_api either way.
Of course, we can still work around this in the wrappers here. But it wouldn't be portable to other array API compliant libraries (not that any necessarily exist). But this shouldn't be too hard to work around in upstream code either.
I think the `numpy.array_api` support for ints is a bug indeed, and requiring a tuple is good practice here. It's easy to fix this up in scikit-learn, so I suggest doing it there rather than in this repo.
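A minimal sketch of what such a fix in calling code could look like, assuming two hypothetical helpers, `as_shape_tuple` and `portable_reshape` (neither name comes from scikit-learn or from this repo). The idea is simply to normalize an integer shape to a one-element tuple before handing it to the namespace's `reshape`, so the call matches the spec for any compliant library:

```python
import operator


def as_shape_tuple(shape):
    """Return `shape` as a tuple of ints, treating a bare int as a 1-D shape.

    Hypothetical helper for illustration only.
    """
    if isinstance(shape, (tuple, list)):
        return tuple(operator.index(dim) for dim in shape)
    # operator.index rejects floats and other non-integer values.
    return (operator.index(shape),)


def portable_reshape(xp, x, shape):
    """Call `xp.reshape` with a tuple shape, as the array API spec requires.

    `xp` is assumed to be an array API namespace exposing `reshape(x, shape)`.
    """
    return xp.reshape(x, as_shape_tuple(shape))
```

With this in place, `portable_reshape(xp, x, -1)` and `portable_reshape(xp, x, (-1,))` both pass `(-1,)` to the underlying library.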
I'll close this issue because `torch.reshape` is following the spec correctly and there is nothing to change within this repo.
`torch.reshape` requires a tuple as the shape:

Note that an integer `shape` works with `numpy.array_api`: