Open ordabayevy opened 3 years ago
This is not (yet) possible with the current implementation of `make_funsor`, but we'll need something like this if we want to rewrite more of `funsor.terms` with `make_funsor`.
A minimal solution would be to define a `BoundSet` type hint
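```python
import typing

from funsor.factory import Bound

# proposed hint: a frozenset of bound variables (not yet understood by make_funsor)
BoundSet = typing.FrozenSet[Bound]
```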
and hard-code support for `BoundSet` inside `make_funsor`. Then we could write, e.g.:
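```python
import functools

import funsor.ops as ops
from funsor.factory import Fresh, make_funsor
from funsor.terms import Funsor


@make_funsor
def Mean(X: Funsor, axes: BoundSet) -> Fresh[lambda X: X]:
    # sum over every bound axis, then divide by the total number of reduced elements
    return X.reduce(ops.add, axes) / functools.reduce(ops.mul, [ax.output.size for ax in axes])
```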
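To make the "hard-code support for `BoundSet`" step concrete, here is a minimal, self-contained sketch of the annotation check a factory could perform. It does not use funsor: `Bound` below is a hypothetical stand-in marker class and `classify_params` is a hypothetical helper, not part of `make_funsor`'s API. It only shows that a `typing.FrozenSet[Bound]` annotation can be told apart from a plain `Bound` with the standard `typing.get_origin` / `typing.get_args` functions.

```python
import inspect
import typing


class Bound:
    """Hypothetical stand-in for funsor's Bound marker (illustration only)."""


BoundSet = typing.FrozenSet[Bound]


def classify_params(fn):
    """Hypothetical helper: map each parameter name to 'bound', 'bound_set' or 'other'."""
    kinds = {}
    for name, param in inspect.signature(fn).parameters.items():
        hint = param.annotation
        if hint is Bound:
            # a single bound variable, as make_funsor already supports
            kinds[name] = "bound"
        elif typing.get_origin(hint) is frozenset and typing.get_args(hint) == (Bound,):
            # the proposed BoundSet: every variable in the set is bound
            kinds[name] = "bound_set"
        else:
            kinds[name] = "other"
    return kinds


def Mean(X: object, axes: BoundSet) -> object:
    ...


print(classify_params(Mean))  # {'X': 'other', 'axes': 'bound_set'}
```

A real implementation would presumably then treat every variable passed in a `bound_set` argument the same way it treats a single `Bound` argument today.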
The `mean`, `var`, and `standardize` functions in Section 3.1.6 (Normalization layers) accept multiple named axes (e.g., the two axes `batch` and `layer` in `BatchNorm`, and the single axis `layer` in `LayerNorm`). How can I define `Mean` and `Standardize` so that they can accept a varying number of bound variables?
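For reference, the semantics being asked for can be sketched eagerly in NumPy, outside of funsor: axis names are kept in a tuple alongside the array, and each function takes an arbitrary collection of names to reduce over. The helper names and the `eps` default are assumptions for illustration only; `mean` mirrors the proposed `Mean` above (sum over the axes, then divide by the product of their sizes).

```python
import numpy as np


def _positions(names, axes):
    # translate axis names into numpy axis positions
    return tuple(names.index(a) for a in axes)


def mean(x, names, axes):
    """Mean over any number of named axes: sum, then divide by the product of their sizes."""
    pos = _positions(names, axes)
    count = int(np.prod([x.shape[p] for p in pos]))
    # keepdims=True keeps reduced axes as size-1 dims so results broadcast against x
    return x.sum(axis=pos, keepdims=True) / count


def var(x, names, axes):
    """Population variance (ddof=0) over the named axes."""
    return mean((x - mean(x, names, axes)) ** 2, names, axes)


def standardize(x, names, axes, eps=1e-5):
    """Subtract the mean and divide by sqrt(var + eps) over the named axes."""
    return (x - mean(x, names, axes)) / np.sqrt(var(x, names, axes) + eps)


names = ("batch", "layer")
x = np.arange(12.0).reshape(3, 4)
bn = standardize(x, names, {"batch", "layer"})  # two axes, BatchNorm-style
ln = standardize(x, names, {"layer"})           # one axis, LayerNorm-style
```

The same three functions serve both layers; only the set of axis names changes, which is exactly what a `BoundSet` argument would express at the funsor level.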