pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Support ops.sum(data, dim=None, keepdims=False) #490

Closed: fritzo closed this 3 years ago

fritzo commented 3 years ago

Addresses #489. Pair-coded with @eb8680, @ordabayevy, and @fehiepsi.

This demonstrates the new parametrized-op syntax from #491. The recipe is:

  1. add the new *args, **kwargs to your op in funsor.ops.array
  2. add backend support in funsor.torch.ops and funsor.jax.ops
  3. implement find_domain(op, ...) in funsor.domains
  4. work around batch dimensions in funsor.tensor
  5. add some tests
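Steps 3 and 4 of the recipe can be illustrated with a plain-Python sketch. It does not use funsor's API; the helpers `sum_event_shape` and `to_full_dims` are hypothetical. The first computes the output event shape of `ops.sum(data, dim, keepdims)` from the input shape, which is the kind of rule a `find_domain` implementation has to encode. The second shifts an event-relative `dim` past the leading batch dimensions of the underlying array, the sort of adjustment step 4 refers to. It assumes, as in funsor tensors, that batch dimensions come before event dimensions, and that `dim=None` reduces over event dimensions only.

```python
from typing import Optional, Tuple, Union

# dim may be None (all event dims), a single int, or a tuple of ints.
Dim = Optional[Union[int, Tuple[int, ...]]]


def _normalize_dims(dim: Dim, ndim: int) -> Tuple[int, ...]:
    """Return sorted, nonnegative axes; dim=None means every axis."""
    if dim is None:
        return tuple(range(ndim))
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    out = []
    for d in dims:
        if not -ndim <= d < ndim:
            raise ValueError(f"dim {d} out of range for ndim {ndim}")
        out.append(d % ndim)  # map negative dims to nonnegative ones
    return tuple(sorted(set(out)))


def sum_event_shape(shape: Tuple[int, ...], dim: Dim = None,
                    keepdims: bool = False) -> Tuple[int, ...]:
    """Hypothetical shape rule for ops.sum, following numpy/torch semantics."""
    axes = _normalize_dims(dim, len(shape))
    if keepdims:
        # Reduced axes survive with size 1.
        return tuple(1 if i in axes else n for i, n in enumerate(shape))
    # Reduced axes are dropped.
    return tuple(n for i, n in enumerate(shape) if i not in axes)


def to_full_dims(dim: Dim, batch_ndim: int, event_ndim: int) -> Tuple[int, ...]:
    """Map event-relative dims to axes of the full (batch + event) array."""
    axes = _normalize_dims(dim, event_ndim)
    return tuple(batch_ndim + d for d in axes)
```

For example, an array of shape `(5, 2, 3, 4)` with one batch dimension has event shape `(2, 3, 4)`; summing over event `dim=-1` means reducing array axis `to_full_dims(-1, 1, 3) == (3,)`, and the resulting event shape is `sum_event_shape((2, 3, 4), -1) == (2, 3)`. Keeping the shape rule separate from the batch offset mirrors the recipe's split between `funsor.domains` and `funsor.tensor`.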

Tested

fritzo commented 3 years ago

@ordabayevy I hope this can serve as a template for your #482. Feel free to refactor after this PR merges, in case you'd like to reuse the logic, e.g. in find_domain or eager.register. We may even be able to create a subclass ReductionOp(UnaryOp) and switch the decorator:

- @UnaryOp.make
+ @ReductionOp.make
  def sum(...):
      ...

and then register ops.sum, ops.mean, etc. with a single pattern.
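The `ReductionOp` idea can be sketched as a toy model in plain Python. This is not funsor's implementation: `Op`, its `make` classmethod, and the `reduced_shape` helper below are hypothetical stand-ins. The point it illustrates is that if `make` builds an instance of whichever class it is called on, changing the decorator from `UnaryOp.make` to `ReductionOp.make` lets one `isinstance(op, ReductionOp)` check (or one dispatch registration) cover `sum`, `mean`, and any later reduction, instead of one rule per op.

```python
import builtins
from typing import Callable, Optional, Sequence, Tuple


class Op:
    """Toy stand-in for an op: a named callable with a default implementation."""

    def __init__(self, name: str, fn: Callable):
        self.name = name
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    @classmethod
    def make(cls, fn: Callable) -> "Op":
        # Used as a decorator; builds an instance of the subclass it is called on.
        return cls(fn.__name__, fn)


class UnaryOp(Op):
    pass


class ReductionOp(UnaryOp):
    """Hypothetical subclass that lets sum, mean, etc. share registrations."""


@UnaryOp.make
def neg(x: float) -> float:
    return -x


@ReductionOp.make
def sum(data: Sequence[float], dim: Optional[int] = None, keepdims: bool = False):
    # Toy default for flat sequences only; dim is accepted for signature
    # compatibility, since a flat sequence has a single axis.
    total = builtins.sum(data)
    return [total] if keepdims else total


@ReductionOp.make
def mean(data: Sequence[float], dim: Optional[int] = None, keepdims: bool = False):
    value = builtins.sum(data) / len(data)
    return [value] if keepdims else value


def reduced_shape(op: Op, shape: Tuple[int, ...], dim: Optional[int] = None,
                  keepdims: bool = False) -> Tuple[int, ...]:
    """One shape rule that applies to every ReductionOp instance."""
    if not isinstance(op, ReductionOp):
        raise NotImplementedError(f"no shape rule for {op.name}")
    axes = range(len(shape)) if dim is None else [dim % len(shape)]
    if keepdims:
        return tuple(1 if i in axes else n for i, n in enumerate(shape))
    return tuple(n for i, n in enumerate(shape) if i not in axes)
```

Here `reduced_shape(mean, (2, 3), dim=-1)` gives `(2,)`, while `reduced_shape(neg, (2, 3))` raises `NotImplementedError` because `neg` was built with `UnaryOp.make`. Adding a new reduction then only needs a new decorated function; the shared rule picks it up through the class.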