pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Support size variables in funsor.domains (dependent types) #214

Open fritzo opened 5 years ago

fritzo commented 5 years ago

Currently, @funsor.torch.function must wrap each sized matmul individually, e.g.

from funsor.domains import reals
from funsor.torch import function

@function(reals(2, 3), reals(3, 4), reals(2, 4))
def matmul_2_3_4(x, y):
    return x.matmul(y)

Could we generalize this to einsum-style syntax with strings as size variables?

@function(reals("a", "b"), reals("b", "c"), reals("a", "c"))
def matmul(x, y):
    return x.matmul(y)
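To make the proposal concrete: checking such a signature amounts to binding each string size variable to a concrete size when the inputs arrive, checking that repeated variables agree, and substituting the bindings into the output domain. A plain-Python sketch of that step, independent of funsor; infer_output_shape and the tuple encoding of the signature are hypothetical:

```python
from typing import Dict, Sequence, Tuple, Union

# A size is either a concrete int or a string naming a size variable.
Size = Union[int, str]


def infer_output_shape(
    input_specs: Sequence[Tuple[Size, ...]],
    output_spec: Tuple[Size, ...],
    input_shapes: Sequence[Tuple[int, ...]],
) -> Tuple[int, ...]:
    """Bind string size variables against concrete input shapes, then
    substitute the bindings into the output spec."""
    if len(input_specs) != len(input_shapes):
        raise TypeError("wrong number of arguments")
    bindings: Dict[str, int] = {}
    for spec, shape in zip(input_specs, input_shapes):
        if len(spec) != len(shape):
            raise TypeError(f"expected rank {len(spec)}, got shape {shape}")
        for s, n in zip(spec, shape):
            if isinstance(s, str):
                # The first occurrence binds the variable; later ones must agree.
                if bindings.setdefault(s, n) != n:
                    raise TypeError(f"size variable {s!r} bound to {bindings[s]} and {n}")
            elif s != n:
                raise TypeError(f"expected size {s}, got {n}")
    return tuple(bindings[s] if isinstance(s, str) else s for s in output_spec)


# Encodes reals("a", "b"), reals("b", "c") -> reals("a", "c").
MATMUL_SPEC = ([("a", "b"), ("b", "c")], ("a", "c"))

print(infer_output_shape(*MATMUL_SPEC, [(2, 3), (3, 4)]))  # (2, 4)
```

An output variable that never appears in any input would raise KeyError here, which suggests a real design might reject such signatures when the decorator is applied.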
eb8680 commented 3 years ago

@fritzo and I came up with what seems like a nice design in #442. Inspired by the notation in that PR, we might write the matmul example above as

@function
def matmul(
    x: Array,
    y: Array
) -> Dependent[lambda x, y: Array[x.dtype, (x.shape[0], y.shape[1])]]:
    return x.matmul(y)

where Dependent, like Fresh in #442, takes as its argument a lambda that receives the domains of x and y at the time matmul is called and returns the result domain.
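A minimal sketch of how such an annotation could be evaluated, assuming Dependent[...] simply stores the lambda and the decorator later calls it with the argument domains. ArrayDomain, Dependent, and output_domain here are hypothetical stand-ins, not the classes from #442:

```python
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class ArrayDomain:
    """Hypothetical stand-in for an Array domain: dtype plus shape."""
    dtype: str
    shape: Tuple[int, ...]


class Dependent:
    """Wraps a function from input domains to the output domain."""

    def __init__(self, fn: Callable[..., ArrayDomain]):
        self.fn = fn

    def __class_getitem__(cls, fn):
        # Dependent[lambda ...] builds an instance rather than a generic alias.
        return cls(fn)


def output_domain(fn, *input_domains):
    """Resolve fn's return annotation against the domains of its arguments."""
    ann = fn.__annotations__["return"]
    return ann.fn(*input_domains) if isinstance(ann, Dependent) else ann


def matmul(x, y) -> Dependent[lambda x, y: ArrayDomain(x.dtype, (x.shape[0], y.shape[1]))]:
    return x.matmul(y)


print(output_domain(matmul, ArrayDomain("real", (2, 3)), ArrayDomain("real", (3, 4))))
# ArrayDomain(dtype='real', shape=(2, 4))
```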

One open question for this approach is whether to allow more specific dependent type annotations for x and y, analogous to the shape variables a, b, c above, and how useful that would be for shape checking.

fritzo commented 3 years ago

@eb8680 interesting... I guess we could even enforce constraints, something like

def assume(constraint, value):
    if not constraint:
        raise TypeError
    return value

@function
def matmul(
    x: Dependent[lambda x: assume(isinstance(x, Array) and len(x.shape) == 2, x)],
    y: Dependent[lambda y: assume(isinstance(y, Array) and len(y.shape) == 2, y)],
) -> Dependent[lambda x, y: assume(x.dtype == y.dtype and x.shape[1] == y.shape[0],
                                   Array[x.dtype, (x.shape[0], y.shape[1])])]:
    return x.matmul(y)

or with a Where[...] type.
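A minimal sketch of the same checks applied directly to stand-in domains, outside any decorator machinery. ArrayDomain, as_matrix, and matmul_output are hypothetical; assume is the helper above with an error message added:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ArrayDomain:
    """Hypothetical stand-in for an Array domain: dtype plus shape."""
    dtype: str
    shape: Tuple[int, ...]


def assume(constraint: bool, value):
    """Return value unchanged if constraint holds, otherwise reject it."""
    if not constraint:
        raise TypeError("dependent type constraint violated")
    return value


def as_matrix(d: ArrayDomain) -> ArrayDomain:
    # Mirrors the input annotations: each argument must be a rank-2 Array.
    return assume(isinstance(d, ArrayDomain) and len(d.shape) == 2, d)


def matmul_output(x: ArrayDomain, y: ArrayDomain) -> ArrayDomain:
    # Check the inputs first so the shape indexing below is safe.
    x, y = as_matrix(x), as_matrix(y)
    # Mirrors the output annotation: dtypes and inner dimensions must agree.
    return assume(
        x.dtype == y.dtype and x.shape[1] == y.shape[0],
        ArrayDomain(x.dtype, (x.shape[0], y.shape[1])),
    )


print(matmul_output(ArrayDomain("real", (2, 3)), ArrayDomain("real", (3, 4))))
# ArrayDomain(dtype='real', shape=(2, 4))
```

Because Python evaluates both arguments of assume before calling it, the result domain is built even when the constraint fails; checking the input ranks first keeps that indexing safe.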

eb8680 commented 3 years ago

I like the idea of having assume or Where, where Where takes a base type and a dependent predicate. We could even allow their use simultaneously.
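One possible reading of Where as a base type paired with a predicate, sketched in plain Python. This assumes single-argument predicates and an ordinary class as the base (rather than a Dependent[...]); Where, check, and ArrayDomain are hypothetical names:

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class ArrayDomain:
    """Hypothetical stand-in for an Array domain: dtype plus shape."""
    dtype: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Where:
    """A base type refined by a predicate on values of that type."""
    base: type
    predicate: Callable[[Any], bool]

    def __class_getitem__(cls, params):
        # Where[Base, pred] builds an instance, mirroring the thread's notation.
        base, predicate = params
        return cls(base, predicate)

    def check(self, value):
        if not isinstance(value, self.base):
            raise TypeError(f"expected {self.base.__name__}, got {type(value).__name__}")
        if not self.predicate(value):
            raise TypeError(f"refinement predicate failed for {value}")
        return value


Matrix = Where[ArrayDomain, lambda d: len(d.shape) == 2]

print(Matrix.check(ArrayDomain("real", (2, 3))))
# ArrayDomain(dtype='real', shape=(2, 3))
```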

A downside of this notation is that it quickly becomes verbose when specifying complicated shape constraints. Maybe it would still be useful to have a notation like the one in the original post, specifically for Array, that desugars to Dependent/Where. This also has the virtue of nudging users away from writing shape-arithmetic constraints, which are harder to get right or to compose.

For example, we could write a batched matmul as:

@function
def matmul(
    x: Array['real', (..., "a", "b")],
    y: Array['real', (..., "b", "c")],
) -> Array['real', (..., "a", "c")]:
    return x.matmul(y)

which might desugar to

@function
def matmul(
    x: Where[Dependent[lambda x: Array[x.dtype, x.shape]], lambda x: len(x.shape) >= 2],
    y: Where[Dependent[lambda y: Array[y.dtype, y.shape]], lambda y: len(y.shape) >= 2],
) -> Where[Dependent[lambda x, y: Array['real', broadcast_shape(x.shape[:-2], y.shape[:-2]) + (x.shape[-2], y.shape[-1])]], 
           lambda x, y: is_broadcastable(x.shape[:-2], y.shape[:-2]) and x.shape[-1] == y.shape[-2] and x.dtype == y.dtype]:  # generated from shape variables
    return x.matmul(y)
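broadcast_shape and is_broadcastable are used above without definitions. A sketch of both, assuming NumPy-style broadcasting (align shapes on the right; paired sizes must be equal or 1), together with a hypothetical batched_matmul_shape that applies the desugared shape predicates to concrete shapes (the dtype check is omitted):

```python
from itertools import zip_longest
from typing import Tuple

Shape = Tuple[int, ...]


def is_broadcastable(s1: Shape, s2: Shape) -> bool:
    """Align shapes on the right; paired sizes must be equal or 1."""
    return all(a == b or a == 1 or b == 1 for a, b in zip(reversed(s1), reversed(s2)))


def broadcast_shape(s1: Shape, s2: Shape) -> Shape:
    if not is_broadcastable(s1, s2):
        raise ValueError(f"shapes {s1} and {s2} are not broadcastable")
    # Pad the shorter shape with 1s on the left; a size of 1 yields to the other.
    pairs = zip_longest(reversed(s1), reversed(s2), fillvalue=1)
    return tuple(reversed([b if a == 1 else a for a, b in pairs]))


def batched_matmul_shape(x_shape: Shape, y_shape: Shape) -> Shape:
    # Same shape checks as the desugared Where predicates.
    if len(x_shape) < 2 or len(y_shape) < 2:
        raise TypeError("both arguments need at least 2 dims")
    if not is_broadcastable(x_shape[:-2], y_shape[:-2]) or x_shape[-1] != y_shape[-2]:
        raise TypeError(f"incompatible shapes {x_shape} and {y_shape}")
    return broadcast_shape(x_shape[:-2], y_shape[:-2]) + (x_shape[-2], y_shape[-1])


print(batched_matmul_shape((5, 1, 2, 3), (4, 3, 7)))  # (5, 4, 2, 7)
```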

One implication of such an expressive design is that only ground types could be supported in the inputs and outputs of actual funsors; e.g., we couldn't allow the construction of funsor.Variable(name, Dependent[...]).

Another set of operations that challenge this design are those whose output shapes depend on an argument's value, even when that value is guaranteed to be an integer or boolean literal. This is a problem because in #442 the function passed to Fresh takes domains as arguments, not values. The simplest such operation is .sum(dim=dim):

@function
def sum_one_dim(
    x: Array,
    dim: int
) -> Dependent[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim+1:]]]:
    return x.sum(dim)
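Writing the shape rule as an ordinary function of a concrete shape and a concrete integer makes the value dependence explicit. The sketch below (sum_one_dim_shape is a hypothetical name) also normalizes negative dims: as written, x.shape[:dim] + x.shape[dim+1:] misbehaves for dim=-1, since dim+1 is then 0 and a shape of (2, 3, 4) would become (2, 3, 2, 3, 4).

```python
from typing import Tuple


def sum_one_dim_shape(shape: Tuple[int, ...], dim: int) -> Tuple[int, ...]:
    """Output shape of x.sum(dim); depends on the value of dim, not just its type."""
    rank = len(shape)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    dim %= rank  # normalize negative dims, e.g. -1 -> rank - 1
    return shape[:dim] + shape[dim + 1:]


print(sum_one_dim_shape((2, 3, 4), 1))   # (2, 4)
print(sum_one_dim_shape((2, 3, 4), -1))  # (2, 3)
```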

For this to work as written, lambda x, dim: ... would have to take the value of dim, not its type int. This forces us to choose a consistent behavior for the other argument x:

  1. Should this lambda take the value of x by default, rather than its .output? This does not seem ideal if x could be a Funsor.
  2. Should we require a special notation indicating that the return type of sum_one_dim depends on the value of dim (e.g. dim: Value[int] instead of dim: int)?
  3. Should we have a different annotation ValueDependent for types that depend on values, e.g. the return type of sum_one_dim?
  4. Alternatively, should we simply disallow value-dependent types and say that sum_one_dim should either use a less specific signature or be written as a Funsor term?

Example of 2:

@function
def sum_one_dim(
    x: Array,
    dim: Value[int]
) -> Dependent[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim+1:]]]:
    return x.sum(dim)
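Option 2 could be implemented by reading the annotations when the output domain is resolved: parameters marked Value[...] pass their runtime value to the Dependent lambda, and all other parameters pass their domain. A sketch under that assumption; Value, Dependent, output_domain, domain_of, and FakeArray are hypothetical stand-ins, not funsor APIs:

```python
import inspect
from collections import namedtuple
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ArrayDomain:
    """Hypothetical stand-in for an Array domain: dtype plus shape."""
    dtype: str
    shape: Tuple[int, ...]


class Value:
    """Marks a parameter whose runtime value is passed to Dependent lambdas."""

    def __init__(self, tp):
        self.tp = tp

    def __class_getitem__(cls, tp):
        return cls(tp)


class Dependent:
    """Wraps a function that computes the output domain."""

    def __init__(self, fn):
        self.fn = fn

    def __class_getitem__(cls, fn):
        return cls(fn)


def output_domain(fn, *args, domain_of):
    """Evaluate fn's Dependent return annotation: Value[...] parameters pass
    their values, all other parameters pass their domains via domain_of."""
    sig = inspect.signature(fn)
    bound = sig.bind(*args)
    resolved = [
        arg if isinstance(sig.parameters[name].annotation, Value) else domain_of(arg)
        for name, arg in bound.arguments.items()
    ]
    return sig.return_annotation.fn(*resolved)


def sum_one_dim(
    x,
    dim: Value[int],
) -> Dependent[lambda x, dim: ArrayDomain(x.dtype, x.shape[:dim] + x.shape[dim + 1:])]:
    return x.sum(dim)


# Hypothetical tensor stand-in: only its shape matters here.
FakeArray = namedtuple("FakeArray", ["shape"])

print(output_domain(sum_one_dim, FakeArray((2, 3, 4)), 1,
                    domain_of=lambda a: ArrayDomain("real", a.shape)))
# ArrayDomain(dtype='real', shape=(2, 4))
```

In this sketch arguments default to their domains, so a Funsor-valued x is never forced to a value; only parameters explicitly marked Value[...] must be concrete when the output domain is computed.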