Closed by nschloe 6 months ago.
Hmm, based on just this example, that may be a bit tricky to make work portably. Is the `dtype` attribute actually needed (I assume `matmul` and `result_type` aren't the actual extent of API usage)? And if so, is it actually necessary that there is a single `Identity` object? If not, changing the example to this should work:
```python
import numpy as np

class Identity:
    def __init__(self, dtype):
        self.dtype = dtype

x = np.array([1, 2, 3])
I = Identity(x.dtype)
```
I'm not sure that this kind of duck typing can be made to work, at least in a guaranteed way, in the current spec. There's nothing in `result_type` that says it should support anything other than `array` and `dtype` inputs, where `array` is the array API library's array type (or a union of array types). It doesn't say anywhere that it should get the dtype from array instances in a duck-typable way using `.dtype`.
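What the spec does guarantee is that `result_type` accepts dtypes, so one portable route is to read `.dtype` yourself and pass only dtypes on. The sketch below assumes that every operand exposes a `.dtype` attribute; `result_type_of` is a hypothetical helper, not part of any library. In real code `xp` might come from something like `array_api_compat.array_namespace(A, b)`; here NumPy is passed directly.

```python
import numpy as np


def result_type_of(xp, *operands):
    """Hypothetical helper: promote operands via their dtypes only.

    Every operand must expose a ``.dtype`` attribute. Arrays do, and so can
    duck-typed operators. Only the extracted dtypes are passed on, because
    the spec guarantees that ``result_type`` accepts dtypes.
    """
    return xp.result_type(*(op.dtype for op in operands))


class Identity:
    """Duck-typed identity operator carrying only a dtype."""

    def __init__(self, dtype):
        self.dtype = dtype


A = np.ones((3, 3), dtype=np.float32)
b = np.ones(3, dtype=np.float32)
I = Identity(np.dtype(np.uint8))
print(result_type_of(np, A, b, I))  # float32
```

The duck-typed object never reaches `result_type` itself, which sidesteps the question of whether a library will look up `.dtype` on non-array inputs.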
That said, and setting those issues aside, one problem with your "lowest possible" dtype is that integer and float dtypes don't promote to each other in the spec. You could have two objects, one for integers with `uint8` and one for floats with `float32`. Also note, though, that `uint8` and `int8` are equally low in the dtype promotion graph, so choosing one will always cause the other to promote to `int16`.
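The promotion behaviour described above can be checked with NumPy's `result_type`. The two sentinel dtypes below are illustrative names only; `float32` is used as the float sentinel because it is the smallest real floating dtype in the array API standard. Note that NumPy goes beyond the standard by promoting across kinds.

```python
import numpy as np

INT_SENTINEL = np.dtype(np.uint8)      # hypothetical "lowest" integer dtype
FLOAT_SENTINEL = np.dtype(np.float32)  # hypothetical "lowest" float dtype

# The float sentinel never elevates float data:
print(np.result_type(FLOAT_SENTINEL, np.float32))  # float32
print(np.result_type(FLOAT_SENTINEL, np.float64))  # float64

# The uint8 sentinel is harmless for unsigned data...
print(np.result_type(INT_SENTINEL, np.uint16))  # uint16
# ...but elevates int8 data, since no 8-bit type holds both ranges:
print(np.result_type(INT_SENTINEL, np.int8))  # int16

# NumPy itself promotes across kinds (uint8 + float32 -> float32), but the
# array API standard leaves mixed-kind promotion unspecified, so portable
# code should not rely on this.
print(np.result_type(INT_SENTINEL, np.float32))  # float32
```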
TBH, what you likely really want is to use a library like JAX that lets you just write `eye()` and optimizes the computation graph without explicitly creating the full array.
> Is the `dtype` attribute actually needed?
A typical use looks like this:
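```python
def cool_linear_solver(A, b, M=None):
    # xp is the array API namespace of A and b
    m, n = A.shape
    assert m == n
    if M is None:  # `M or ...` would call bool() on an array
        M = Identity(n)
    dt = xp.result_type(A, b, M)
    x = xp.zeros(n, dtype=dt)
    # ...
    # x += M @ (b - A @ x)
    # ...
```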
If `M` is given, it might "elevate" `x`'s dtype requirements, but if not, it's just down to `A` and `b`.
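For illustration, here is one way the elided loop might be filled in, assuming a plain preconditioned Richardson iteration, which matches the commented update line. The `Identity` class (with a class-level `"u1"` dtype and a `__matmul__` that returns its operand), the `maxiter`/`tol` parameters and the use of NumPy as `xp` are all assumptions for this sketch. The dtype is computed from `.dtype` attributes rather than from the objects themselves.

```python
import numpy as np


class Identity:
    """Hypothetical cheap identity operator; never materializes eye(n)."""

    # "Lowest possible" dtype, so it does not elevate float inputs.
    dtype = np.dtype(np.uint8)

    def __init__(self, n):
        self.shape = (n, n)

    def __matmul__(self, other):
        # I @ v == v, so hand back the operand unchanged.
        return other


def cool_linear_solver(A, b, M=None, maxiter=200, tol=1e-12):
    """Preconditioned Richardson iteration: x += M @ (b - A @ x)."""
    xp = np  # assumption: NumPy stands in for a generic array namespace
    m, n = A.shape
    assert m == n
    if M is None:
        M = Identity(n)
    # Pass dtypes explicitly; result_type is only specified for arrays and dtypes.
    dt = xp.result_type(A.dtype, b.dtype, M.dtype)
    x = xp.zeros(n, dtype=dt)
    for _ in range(maxiter):
        r = b - A @ x
        if xp.linalg.norm(r) < tol:
            break
        x += M @ r
    return x


A = np.array([[1.0, 0.2], [0.1, 0.9]])
b = np.array([1.0, 2.0])
x = cool_linear_solver(A, b)
print(np.allclose(A @ x, b), x.dtype)  # True float64
```

With `M` left as the identity, this iteration only converges when the spectral radius of `I - A` is below 1; the small matrix above satisfies that.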
Perhaps the easiest thing to do is to write

```python
dt = xp.result_type(A, b, M.dtype)
```

and make explicit the fact that I expect `M` to have a `dtype` attribute. @asmeurer's remark about `int8` vs `uint8` is correct of course, but so far it hasn't given me any trouble (arrays are mostly float, sometimes complex). Not sure what a "proper" solution would look like.
This would work based on the snippet you have given, right?

```python
if M is None:
    dt = xp.result_type(A, b)
    M = Identity(n)
else:
    dt = xp.result_type(A, b, M)
```
@lucascolley Yeah, but there are several other `M`-like entities that would need to be factored in. I could build up a list like

```python
dtype_determining_objects = [A, b]
if M is not None:
    dtype_determining_objects.append(M)
if N is not None:
    dtype_determining_objects.append(N)
if Q is not None:
    dtype_determining_objects.append(Q)
dt = xp.result_type(*dtype_determining_objects)
```

but this feels clumsy.
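A more compact variant of the same list-building idea is to filter out the `None` operators in a generator. `result_type_ignoring_none` is a hypothetical helper, and NumPy stands in for `xp`; like the list version, it passes arrays (not duck-typed objects) to `result_type`.

```python
import numpy as np


def result_type_ignoring_none(xp, *objs):
    """Hypothetical helper: xp.result_type over the arguments that are not None.

    Equivalent to building the ``dtype_determining_objects`` list by hand,
    but without one ``if ... is not None`` block per optional operator.
    """
    return xp.result_type(*(obj for obj in objs if obj is not None))


A = np.eye(3, dtype=np.float32)
b = np.ones(3, dtype=np.float32)
M, N, Q = None, np.ones((3, 3), dtype=np.complex64), None
dt = result_type_ignoring_none(np, A, b, M, N, Q)
print(dt)  # complex64
```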
In some of my numerical codes, I have a class `Identity()` that represents cheap identity matrices. To make it work with things like `result_type`, I assign it the "lowest possible" dtype, `"u1"`:

I would now like to make this work with general `xp` arrays. Assigning a dtype from a concrete implementation, though, isn't what I want. Using `None` as a dtype sounds plausible, but with NumPy it gives a warning and then a wrong `result_type`: