@fritzo and I came up with what seems like a nice design in #442. Inspired by the notation in that PR, we might write the `matmul` example above as:
```python
@function
def matmul(
    x: Array,
    y: Array
) -> Dependent[lambda x, y: Array[x.dtype, (x.shape[0], y.shape[1])]]:
    return x.matmul(y)
```
where `Dependent`, like `Fresh` in #442, takes a lambda as an argument; the lambda receives the domains of the arguments `x` and `y` at the time `matmul` is called and returns the result domain.
One question for this approach is whether to allow more specific dependent type annotations for `x` and `y`, analogous to the shape variables `a, b, c` above, and how useful that would be for shape checking.
@eb8680 interesting... I guess we could even enforce constraints, something like
```python
def assume(constraint, value):
    if not constraint:
        raise TypeError
    return value


@function
def matmul(
    x: Dependent[lambda x: assume(isinstance(x, Array) and len(x.shape) == 2, x)],
    y: Dependent[lambda y: assume(isinstance(y, Array) and len(y.shape) == 2, y)],
) -> Dependent[lambda x, y: assume(x.dtype == y.dtype and x.shape[1] == y.shape[0],
                                   Array[x.dtype, (x.shape[0], y.shape[1])])]:
    return x.matmul(y)
```
or with a `Where[...]` type.
I like the idea of having `assume` or `Where`, where `Where` takes a base type and a dependent predicate. We could even allow their use simultaneously.
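For concreteness, here is one hypothetical way `Where` could be spelled, assuming it simply pairs a base type (possibly itself `Dependent`) with a predicate over argument domains; this is a sketch, not an existing API:

```python
# Hypothetical sketch: Where[base, predicate] pairs a (possibly Dependent) base
# type with a predicate that a checker would assert at call time.
class Where:
    def __init__(self, base, predicate):
        self.base = base
        self.predicate = predicate

    def __class_getitem__(cls, params):
        base, predicate = params
        return cls(base, predicate)

# e.g. Where[Array, lambda x: len(x.shape) == 2] would typecheck a call only
# if the predicate holds for the actual argument's domain.
```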
A downside of this notation is that it quickly becomes verbose when specifying complicated shape constraints. Maybe it would still be useful to have a notation like the one in the original post specifically for `Array` that desugars to `Dependent`/`Where`. This also has the virtue of nudging users away from specifying shape arithmetic constraints, which are harder to get right or compose.
For example, we could write a batched `matmul` as:
```python
@function
def matmul(
    x: Array['real', (..., "a", "b")],
    y: Array['real', (..., "b", "c")],
) -> Array['real', (..., "a", "c")]:
    return x.matmul(y)
```
which might desugar to
```python
@function
def matmul(
    x: Where[Dependent[lambda x: Array[x.dtype, x.shape]], lambda x: len(x.shape) >= 2],
    y: Where[Dependent[lambda y: Array[y.dtype, y.shape]], lambda y: len(y.shape) >= 2],
) -> Where[Dependent[lambda x, y: Array['real', broadcast_shape(x.shape[:-2], y.shape[:-2]) + (x.shape[-2], y.shape[-1])]],
           lambda x, y: is_broadcastable(x.shape[:-2], y.shape[:-2]) and x.shape[-1] == y.shape[-2] and x.dtype == y.dtype]:  # generated from shape variables
    return x.matmul(y)
```
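The desugared signature assumes helpers like `broadcast_shape` and `is_broadcastable`; here is a minimal sketch of what they might look like (NumPy-style broadcasting over the batch dimensions; my illustration, not existing funsor utilities):

```python
# Hypothetical helpers assumed by the desugared signature above.
def is_broadcastable(shape1, shape2):
    # Trailing dimensions must either match or be 1, NumPy-style.
    return all(a == b or a == 1 or b == 1
               for a, b in zip(reversed(shape1), reversed(shape2)))


def broadcast_shape(shape1, shape2):
    # Compute the broadcasted shape, assuming is_broadcastable holds.
    rank = max(len(shape1), len(shape2))
    padded1 = (1,) * (rank - len(shape1)) + tuple(shape1)
    padded2 = (1,) * (rank - len(shape2)) + tuple(shape2)
    return tuple(max(a, b) for a, b in zip(padded1, padded2))
```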
One implication of such an expressive design is that only ground types could be supported in the inputs and outputs of actual funsors, e.g. we couldn't allow the construction of `funsor.Variable(name, Dependent[...])`.
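That is (my illustration; the exact domain spelling, `Reals[3, 4]` vs `reals(3, 4)`, depends on the funsor version), something like:

```python
import funsor
from funsor.domains import Reals  # older funsor versions spell this reals(...)

# Fine under this design: a ground output domain.
x = funsor.Variable("x", Reals[3, 4])

# Rejected under this design: the output domain is not ground.
# y = funsor.Variable("y", Dependent[lambda x: Array[x.dtype, x.shape]])
```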
Another set of operations poses a challenge for this design: those where shapes depend on an argument value, even when that value is guaranteed to be an integer or boolean literal. In #442, the function in `Fresh` takes domains as arguments rather than values. The simplest such operation is `.sum(dim=dim)`:
```python
@function
def sum_one_dim(
    x: Array,
    dim: int
) -> Dependent[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim+1:]]]:
    return x.sum(dim)
```
For this to work as written, `lambda x, dim: ...` would have to take the value of `dim`, not its type `int`. This forces us to choose a consistent behavior for the other argument `x`:
1. Should the `lambda` take the value of `x` by default, rather than its `.output`? This does not seem ideal if `x` could be a Funsor.
2. Should we annotate that `sum_one_dim` depends on the value of `dim` (e.g. `dim: Value[int]` instead of `dim: int`)?
3. Should we introduce a `ValueDependent` for types that depend on values, e.g. the return type of `sum_one_dim`?
4. Should we say that `sum_one_dim` should either use a less specific signature or be written as a Funsor term?

Example of 2:
```python
@function
def sum_one_dim(
    x: Array,
    dim: Value[int]
) -> Dependent[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim+1:]]]:
    return x.sum(dim)
```
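To make option 2 concrete, a checker could apply the `Dependent` lambda to the domain of `x` and the literal value of `dim`; here is a worked example of the shape arithmetic (my illustration, not library output):

```python
# Hypothetical checker behavior for option 2: suppose x has domain
# Array['real', (2, 3, 4)] and the call site passes the literal dim=1.
x_shape = (2, 3, 4)
dim = 1
result_shape = x_shape[:dim] + x_shape[dim + 1:]
assert result_shape == (2, 4)  # so the result domain is Array['real', (2, 4)]
```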
Currently `@funsor.torch.function` must wrap each sized matmul individually. Could we generalize this to einsum-style syntax with strings as size variables?
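For illustration (a sketch, not the example from the original post; the exact domain spelling, `Reals[3, 4]` vs `reals(3, 4)`, depends on the funsor version), per-size wrapping looks something like:

```python
import torch

import funsor.torch
from funsor.domains import Reals  # older funsor versions spell this reals(...)


# Each concrete shape needs its own wrapped function, because the decorator
# takes fixed input domains followed by a fixed output domain.
@funsor.torch.function(Reals[3, 4], Reals[4, 5], Reals[3, 5])
def matmul_34_45(x, y):
    return torch.matmul(x, y)


@funsor.torch.function(Reals[2, 3], Reals[3, 7], Reals[2, 7])
def matmul_23_37(x, y):
    return torch.matmul(x, y)
```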