data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

`Identity()` operator dtype #739

Closed nschloe closed 6 months ago

nschloe commented 8 months ago

In some of my numerical codes, I have a class Identity() that represents cheap identity matrices. To make it work with things like result_type, I assign it the "lowest possible" dtype, "u1":

import numpy as np
from array_api_compat import array_namespace

class Identity:
    dtype = np.dtype("u1")

    def __matmul__(self, x):
        return x

I = Identity()
x = np.array([1,2,3])

xp = array_namespace(x)
print(xp.result_type(x, I))
int64

I would now like to make this work with general xp arrays. Assigning the dtype of one concrete implementation isn't what I want, though. Using None as the dtype sounds plausible, but with NumPy it gives a warning and then the wrong result_type:

DeprecationWarning: in the future the `.dtype` attribute of a given datatype object must be a valid dtype instance. `data_type.dtype` may need to be coerced using `np.dtype(data_type.dtype)`. (Deprecated NumPy 1.20)
  print(xp.result_type(x, I))
float64
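
A possible explanation for the float64 result, assuming NumPy coerces the attribute roughly the way the warning text suggests (through np.dtype(...)): np.dtype(None) is NumPy's default dtype, float64, so a None dtype does not act as "no opinion" but promotes everything to float. A minimal check of that building block, without the Identity class:

```python
import numpy as np

# np.dtype(None) falls back to NumPy's default dtype rather than "no dtype".
none_dtype = np.dtype(None)
print(none_dtype)

# Promoting an integer dtype with that default yields a floating-point result.
promoted = np.result_type(np.dtype("int64"), none_dtype)
print(promoted)
```

This only illustrates the coercion of None; how result_type treats arbitrary objects carrying a .dtype attribute may differ between NumPy versions.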

rgommers commented 8 months ago

Hmm, based on just this example that may be a bit tricky to make work portably. Is the dtype attribute actually needed (I assume matmul and result_type aren't the full extent of API usage)? And if so, is it actually necessary that there is a single Identity object? If not, changing the example to this should work:

class Identity:
    def __init__(self, dtype):
        self.dtype = dtype

x = np.array([1,2,3])
I = Identity(x.dtype)

asmeurer commented 8 months ago

I'm not sure that this kind of duck typing can be made to work, at least in a guaranteed way, in the current spec. There's nothing in result_type that says it should support anything other than array and dtype inputs, where array is the array API library array type (or union of array types). It doesn't say anywhere that it should get the dtype from array instances in a duck typable way using .dtype.

With that being said, ignoring those issues, one issue with your "lowest possible" dtype is that integer and float dtypes don't promote to each other. You could have two objects, one for integers with uint8 and one for floats with float32. Although also note that uint8 and int8 are both equally low in the dtype promotion graph, so choosing one will always cause the other to promote to int16.
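
The promotion behaviour described above can be checked with a short sketch. It uses NumPy's own np.result_type as a stand-in; NumPy's rules are broader than the array API standard, so the mixed integer/float case below is NumPy-specific behaviour and not something the standard guarantees.

```python
import numpy as np

u1 = np.dtype("uint8")

# Same kind: uint8 sits low in the promotion graph, so the other operand wins.
same_kind = np.result_type(u1, np.dtype("int64"))

# uint8 and int8 sit at the same level, so neither can hold the other;
# the result is the next larger signed type.
tie = np.result_type(u1, np.dtype("int8"))

# Mixed integer/float promotion is left unspecified by the array API standard;
# NumPy happens to define it.
mixed = np.result_type(u1, np.dtype("float32"))

print(same_kind, tie, mixed)
```

A library that follows the standard strictly is free to raise on the last call, which is why a single "lowest possible" dtype is not portable across integer and floating-point inputs.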

TBH, what you likely really want is to use a library like JAX that lets you just write eye() and optimizes the computation graph without explicitly creating the full array.

nschloe commented 7 months ago

Is the dtype attribute actually needed

A typical use is, e.g.,

def cool_linear_solver(A, b, M=None):
    m, n = A.shape
    assert m == n
    # "M or ..." would raise for array-valued M (ambiguous truth value)
    if M is None:
        M = Identity(n)

    dt = xp.result_type(A, b, M)
    x = xp.zeros(n, dtype=dt)

    # ...
    # x += M @ (b - A @ x)
    # ...

If M is given, it might "elevate" x's dtype requirements, but if not, it's just down to A and b.

Perhaps the easiest thing to do is to write

dt = xp.result_type(A, b, M.dtype)

and make explicit the fact that I expect M to have a dtype attribute. @asmeurer's remark about int8 vs uint8 is correct of course, but so far hasn't given me any trouble (arrays are mostly float, sometimes complex). Not sure what a "proper" solution would look like.

lucascolley commented 7 months ago

This would work based on the snippet you have given, right?

if M is None:
    dt = xp.result_type(A, b)
    M = Identity(n)
else:
    dt = xp.result_type(A, b, M)

nschloe commented 7 months ago

@lucascolley Yeah, but there are several other M-like entities that would need to be factored in. I could build up a list like

dtype_determining_objects = [A, b]
if M is not None:
    dtype_determining_objects.append(M)
if N is not None:
    dtype_determining_objects.append(N)
if Q is not None:
    dtype_determining_objects.append(Q)

dt = xp.result_type(*dtype_determining_objects)

but this feels clumsy.
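
One way to reduce the repetition is a small helper that skips absent operands and hands only dtypes to result_type, which stays within what the standard says result_type accepts (arrays and dtypes). The following is a sketch, not something proposed in the thread: common_dtype is a hypothetical name, the Identity class with dtype = None is an assumption for illustration, and NumPy stands in for a general xp namespace.

```python
import numpy as np


class Identity:
    """Hypothetical identity operator that has no dtype opinion of its own."""

    dtype = None

    def __matmul__(self, x):
        return x


def common_dtype(xp, *operands):
    """Promote the dtypes of all operands, ignoring None and dtype-less ones."""
    dtypes = []
    for obj in operands:
        if obj is None:
            continue  # optional operand that was not supplied
        dt = getattr(obj, "dtype", None)
        if dt is not None:
            dtypes.append(dt)  # pass dtypes only, never the duck-typed object
    if not dtypes:
        raise ValueError("no operand carries a dtype")
    return xp.result_type(*dtypes)


A = np.eye(3, dtype=np.float32)
b = np.ones(3, dtype=np.float32)
M = np.eye(3, dtype=np.complex64)

# Absent operands and the dtype-less Identity leave the result to A and b.
dt_plain = common_dtype(np, A, b, None, Identity())
# A complex preconditioner elevates the result dtype.
dt_complex = common_dtype(np, A, b, M)
print(dt_plain, dt_complex)
```

With this, the solver body reduces to a single call such as dt = common_dtype(xp, A, b, M, N, Q), and an identity operator needs no placeholder dtype at all.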