pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Funsor function that can accept varied number of Bound variables #485

Open ordabayevy opened 3 years ago

ordabayevy commented 3 years ago

The mean, var, and standardize functions in section 3.1.6 (Normalization layers) accept multiple named axes (e.g., the two axes batch and layer in BatchNorm, and the single axis layer in LayerNorm). How can I define Mean and Standardize below so that they accept a varying number of Bound variables?

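```python
import funsor.ops as ops
from funsor.factory import Bound, Fresh, make_funsor
from funsor.terms import Funsor

@make_funsor
def Mean(
    X: Funsor,
    ax: Bound
) -> Fresh[lambda X: X]:
    # Sum over the single bound axis, then divide by its size.
    return X.reduce(ops.add, ax) / ax.output.size

# Variance is assumed to be defined analogously to Mean (not shown here).

@make_funsor
def Standardize(
    X: Funsor,
    ax: Bound
) -> Fresh[lambda X: X]:
    return (X - Mean(X, ax)) / (Variance(X, ax) + ops.finfo(X.data).eps).sqrt()
```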
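For reference, here is a minimal eager sketch in NumPy (not funsor) of what these functions compute when the set of normalized axes varies. The helper names mean, var, standardize, and _positions, the tuple-of-axis-names argument, and the use of population variance (NumPy's default ddof=0) are assumptions for illustration only. The point is that one definition can serve both BatchNorm (two axes) and LayerNorm (one axis).

```python
import numpy as np

def _positions(names, axes):
    # Hypothetical helper: map axis names to positional dims of the array.
    return tuple(names.index(a) for a in axes)

def mean(x, names, axes):
    # Mean over any number of named axes; keepdims=True so it broadcasts back.
    return x.mean(axis=_positions(names, axes), keepdims=True)

def var(x, names, axes):
    # Population variance (ddof=0) over the same named axes (an assumption).
    return x.var(axis=_positions(names, axes), keepdims=True)

def standardize(x, names, axes, eps=None):
    if eps is None:
        eps = np.finfo(x.dtype).eps
    return (x - mean(x, names, axes)) / np.sqrt(var(x, names, axes) + eps)

names = ("batch", "layer")
x = np.random.default_rng(0).normal(size=(4, 3))
batch_norm = standardize(x, names, ("batch", "layer"))  # two axes
layer_norm = standardize(x, names, ("layer",))           # one axis
```

Passing the axes as a collection rather than a single name is what lets the same function cover both layers; the funsor question is how to express that collection as a bound argument.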
eb8680 commented 3 years ago

This is not (yet) possible with the current implementation of make_funsor, but we'll need something like this if we want to rewrite more of funsor.terms with make_funsor.

A minimal solution would be to define a BoundSet hint:

```python
BoundSet = typing.FrozenSet[Bound]
```

and hard-code support for BoundSet inside make_funsor. Then we could write, e.g.:


@make_funsor
def Mean(
    X: Funsor,
    axes: BoundSet
) -> Fresh[lambda X: X]:
    return X.reduce(ops.add, axes) / reduce(ops.mul, [ax.output.size for ax in axes])
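To make "hard-code support for BoundSet inside make_funsor" slightly more concrete, here is a standalone sketch of only the annotation-inspection step: telling a parameter annotated with Bound apart from one annotated with FrozenSet[Bound]. It does not use funsor and is not how make_funsor is actually implemented; Bound here is a hypothetical stand-in class, and classify_bound_params, mean_one, and mean_many are hypothetical names. What make_funsor would then do with each member of the set is not shown.

```python
import inspect
import typing
from typing import FrozenSet


class Bound:
    """Hypothetical stand-in for funsor's Bound marker (illustration only)."""


BoundSet = FrozenSet[Bound]  # the hint proposed above


def classify_bound_params(fn):
    """Split fn's parameters into single-Bound and BoundSet parameters."""
    hints = typing.get_type_hints(fn)
    single, multiple = [], []
    for name in inspect.signature(fn).parameters:
        hint = hints.get(name)
        if hint is Bound:
            single.append(name)
        # FrozenSet[Bound] reports the builtin frozenset as its origin.
        elif typing.get_origin(hint) is frozenset and typing.get_args(hint) == (Bound,):
            multiple.append(name)
    return single, multiple


def mean_one(X: object, ax: Bound): ...
def mean_many(X: object, axes: BoundSet): ...

print(classify_bound_params(mean_one))   # (['ax'], [])
print(classify_bound_params(mean_many))  # ([], ['axes'])
```

Comparing typing.get_origin against the builtin frozenset, rather than against typing.FrozenSet, is what makes the check work, since subscripted typing aliases report the builtin type as their origin.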