pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[FR] Simple arithmetic transformation operator overloading #2441

Closed ecotner closed 4 years ago

ecotner commented 4 years ago

Hi, I've been trying to teach myself Pyro for the last couple of weeks, and it is a great package; I'm learning quite a bit about Bayesian methods from going through the tutorials and doing side projects. In the course of my studies, I found one thing that would make creating models with deterministic transformations even more transparent and simple: overloading simple arithmetic operators on Distributions so that you can quickly shift/scale/invert them. I'll give you an example:

Say I have a gamma-distributed RV that I want to shift in the positive direction; in pyro I would have to make that distribution like so:

from pyro.distributions import Gamma, TransformedDistribution
from pyro.distributions.transforms import AffineTransform

X = Gamma(1, 1)
X = TransformedDistribution(X, AffineTransform(1, 1))  # loc=1, scale=1: X -> 1 + X
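With loc = 1 and scale = 1, `AffineTransform(1, 1)` maps x to 1 + x, so the resulting density is the Gamma density shifted right by one. As a rough, stdlib-only sketch of the change-of-variables rule a transformed distribution relies on (the helper names `gamma_log_pdf` and `affine_log_pdf` are hypothetical, not Pyro API):

```python
import math

def gamma_log_pdf(x: float, concentration: float, rate: float) -> float:
    """Log density of Gamma(concentration, rate), rate parameterization as in torch/Pyro."""
    if x <= 0:
        return -math.inf  # outside the support
    return (concentration * math.log(rate)
            + (concentration - 1) * math.log(x)
            - rate * x
            - math.lgamma(concentration))

def affine_log_pdf(y: float, base_log_pdf, loc: float, scale: float) -> float:
    """Log density of Y = loc + scale * X by change of variables:
    log p_Y(y) = log p_X((y - loc) / scale) - log|scale|."""
    x = (y - loc) / scale
    return base_log_pdf(x) - math.log(abs(scale))

def base(x: float) -> float:
    return gamma_log_pdf(x, 1.0, 1.0)

# loc=1, scale=1 mirrors AffineTransform(1, 1): Y = 1 + X.
print(affine_log_pdf(2.5, base, loc=1.0, scale=1.0))  # -1.5, i.e. log p_X(1.5)
print(affine_log_pdf(0.5, base, loc=1.0, scale=1.0))  # -inf: Y never falls below 1
```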

It actually took me quite a long time to figure out that this was even possible, and it would be much simpler if I could write a short one-liner like X = 1 + Gamma(1, 1), but this throws an error because addition between distributions and integers isn't defined. The same goes for other simple binary/unary operations such as subtraction, multiplication, division, negation, etc. I believe PyMC3 already has this kind of feature built in, so it would be great if we could do this with Pyro distributions too.

I've been playing around with a mixin class that makes this sort of thing more transparent (I'm not sure how you'd implement this in reality; maybe add the dunder methods directly to Distribution so that its subclasses inherit them?):

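from pyro.distributions import Gamma, TransformedDistribution
from pyro.distributions.transforms import AbsTransform, AffineTransform

class ArithmeticMixin:
    # AffineTransform(loc, scale) maps x -> loc + scale * x.
    def __add__(self, x: float):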
        return _TransformedDistribution(self, AffineTransform(x, 1))

    def __radd__(self, x: float):
        return _TransformedDistribution(self, AffineTransform(x, 1))

    def __sub__(self, x: float):
        return _TransformedDistribution(self, AffineTransform(-x, 1))

    def __rsub__(self, x: float):
        return _TransformedDistribution(self, AffineTransform(x, -1))

    def __mul__(self, x: float):
        return _TransformedDistribution(self, AffineTransform(0, x))

    def __rmul__(self, x: float):
        return _TransformedDistribution(self, AffineTransform(0, x))

    def __truediv__(self, x: float):
        return _TransformedDistribution(self, AffineTransform(0, 1/x))

    def __neg__(self):
        return _TransformedDistribution(self, AffineTransform(0, -1))

    def __abs__(self):
        # AbsTransform is not bijective: sampling works, but log_prob is not supported.
        return _TransformedDistribution(self, AbsTransform())

class _TransformedDistribution(TransformedDistribution, ArithmeticMixin):
    pass

class _Gamma(Gamma, ArithmeticMixin):
    pass

This allows simple and transparent operations such as

x = _Gamma(1, 1)
x = x + 1
x = 2*x
x = 1-x
x = x/3.14159
x = abs(x)
x.sample([5])  # tensor([1.2474, 0.5195, 0.5959, 3.3581, 0.4880])
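Every step in this chain except the final `abs` is affine, so the chain collapses to one affine map followed by an absolute value: Y = |(-1 - 2X) / 3.14159| = (1 + 2X) / 3.14159 for a Gamma draw X > 0. Every draw is therefore larger than 1/3.14159 ≈ 0.318, consistent with the samples shown. A toy, Pyro-independent sketch that composes the same steps symbolically (the `Affine` class is hypothetical):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Affine:
    """The map x -> loc + scale * x (same convention as AffineTransform(loc, scale))."""
    loc: float = 0.0
    scale: float = 1.0

    def __call__(self, x: float) -> float:
        return self.loc + self.scale * x

    def then(self, other: "Affine") -> "Affine":
        """Apply self first, then other: other(self(x))."""
        return Affine(other.loc + other.scale * self.loc, other.scale * self.scale)

# The affine steps from the example above, in order.
steps = [
    Affine(1.0, 1.0),          # x + 1        (__add__)
    Affine(0.0, 2.0),          # 2 * x        (__rmul__)
    Affine(1.0, -1.0),         # 1 - x        (__rsub__)
    Affine(0.0, 1 / 3.14159),  # x / 3.14159  (__truediv__)
]

total = Affine()  # identity map
for step in steps:
    total = total.then(step)

print(total)  # loc ~ -0.3183, scale ~ -0.6366; abs() is applied afterwards
```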

Would this be a feature worthy of adding to pyro? Or should I open this issue directly with pytorch, since it seems like the distributions module sources a lot of code from them?

fritzo commented 4 years ago

Interesting idea, @ecotner! I think we could add this to Pyro as a RandomVariable class. Moreover we could use the metaclass.__getitem__ trick (as used by PyroModule) to make it easy to promote Distribution objects to RandomVariable objects (via mixin):

import pyro.distributions as dist
from pyro.distributions import RandomVariable as RV

x = RV[dist.Gamma](1, 1)
x = x + 1
...
x.sample([5])

If you want to put up a PR adding a RandomVariable mixin in something like pyro/distributions/random_variable.py, I could help or contribute the __getitem__ trick. If this seems to work out well in Pyro, we could move it upstream to torch.distributions. WDYT?
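The `metaclass.__getitem__` trick can be sketched in plain Python: subscripting the class calls a method on its metaclass, which builds (and caches) a subclass mixing the arithmetic in front of the given base class, similar in spirit to `PyroModule[nn.Linear]`. This is a stdlib toy under that assumption; `RV`, `ArithmeticMixin`, and `ToyExponential` are hypothetical stand-ins, not Pyro classes:

```python
import random

_CACHE = {}  # (mixin class, base class) -> generated subclass

class _RandomVariableMeta(type):
    def __getitem__(cls, base):
        """RV[Base] returns a cached subclass inheriting from both RV and Base."""
        key = (cls, base)
        if key not in _CACHE:
            # type(cls) is this metaclass, so the generated class keeps the [] trick.
            _CACHE[key] = type(cls)(f"{cls.__name__}_{base.__name__}", (cls, base), {})
        return _CACHE[key]

class ArithmeticMixin:
    """Operators only; assumes the host object has a .sample(rng) method."""
    def __add__(self, c):
        return _Transformed(self, lambda x: x + c)
    __radd__ = __add__

    def __mul__(self, c):
        return _Transformed(self, lambda x: x * c)
    __rmul__ = __mul__

class _Transformed(ArithmeticMixin):
    def __init__(self, base, fn):
        self.base, self.fn = base, fn

    def sample(self, rng):
        return self.fn(self.base.sample(rng))

class RV(ArithmeticMixin, metaclass=_RandomVariableMeta):
    pass

class ToyExponential:
    """Stand-in for a Distribution class that knows nothing about RV."""
    def __init__(self, rate):
        self.rate = rate

    def sample(self, rng):
        return rng.expovariate(self.rate)

x = RV[ToyExponential](1.0)  # an RV_ToyExponential instance
y = 2 * (x + 1)
rng = random.Random(0)
print([round(y.sample(rng), 3) for _ in range(3)])  # every value is >= 2
```

Putting the mixin first in the bases lets its operators take precedence, while `__init__` still resolves to the wrapped class.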

Longer term, I think we'd avoid adding this behavior by default. There are actually two possible interpretations of arithmetic operations on distributions. Viewing distributions in the mathematical sense, they are generalizations of density functions, so a natural interpretation of arithmetic operations on distributions is as pointwise operations on densities. This has the opposite variance from the random variable interpretation, which applies operations to the inputs rather than the outputs of those density functions. And Pyro is actually moving to integrate these pointwise operations more deeply via the Funsor library. Once that big refactoring goes through, Pyro will be much better able to automatically detect conjugacy relationships and perform partially analytic inference, but it may make pointwise operations the default interpretation of mathematical operators, conflicting with the RandomVariable interpretation. @eb8680 has thought a lot about making the two interpretations play well together: https://github.com/pyro-ppl/funsor/pull/130
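Two Gaussians make the two readings concrete. Under the density reading, combining them by adding log-densities multiplies the densities pointwise, giving an unnormalized Gaussian bump with precision-weighted mean. Under the random-variable reading, adding independent variables gives a Gaussian whose means and variances add. A stdlib-only sketch; the function names are hypothetical:

```python
import math

def normal_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def pointwise_product(m1, v1, m2, v2):
    """Density reading: N(x; m1, v1) * N(x; m2, v2) is proportional to N(x; m, v)."""
    v = 1.0 / (1.0 / v1 + 1.0 / v2)  # precisions add
    m = v * (m1 / v1 + m2 / v2)      # precision-weighted mean
    return m, v

def sum_of_independent(m1, v1, m2, v2):
    """Random-variable reading: X1 + X2 for independent Gaussians is N(m1 + m2, v1 + v2)."""
    return m1 + m2, v1 + v2

print(pointwise_product(0.0, 1.0, 2.0, 1.0))   # (1.0, 0.5)
print(sum_of_independent(0.0, 1.0, 2.0, 1.0))  # (2.0, 2.0)
```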

ecotner commented 4 years ago

@fritzo yes, I definitely had the random variable interpretation of distributions in mind when suggesting it! And I can understand why you would want to separate the two**.

And yes, I'd be happy to contribute a PR! I'll get started on that soon.

**This is potentially off-topic, but if you wanted to apply pointwise transformations to densities, wouldn't the vast majority of transformations result in non-normalized density functions? E.g. if f(x) were a properly normalized density, h(x) = f(x) + g(x) would no longer be normalized unless \int g(x) dx = 0 (and god forbid you have to normalize g(f(x))). Wouldn't the random variable interpretation of distributions, and of transformations on those distributions, then be more natural? In my experience it is more common to consider operations like X1 + X2, where the X's are RVs, than f1(x) + f2(x), where the f's are the densities underlying those RVs. Perhaps the latter is useful if you were trying to construct some custom prior/variational distribution from more elementary ones... Just my $0.02; far be it from me to tell you how to structure your library :sweat_smile: I'll read the Funsor paper; maybe there's some concept I'm missing.
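The normalization argument can be written out. A pointwise sum with an arbitrary g keeps unit mass only if g integrates to zero, and even a pointwise product of two genuine densities is generally unnormalized (here for the standard normal density φ), whereas a transform of the random variable through an invertible, differentiable T preserves normalization automatically:

```latex
% Pointwise sum with an arbitrary function g
\int h(x)\,dx = \int f(x)\,dx + \int g(x)\,dx = 1 + \int g(x)\,dx

% Pointwise product of two standard normal densities
\int \varphi(x)^2\,dx = \frac{1}{2\pi}\int e^{-x^2}\,dx = \frac{\sqrt{\pi}}{2\pi} = \frac{1}{2\sqrt{\pi}} \approx 0.282

% Random-variable transform Y = T(X)
p_Y(y) = p_X\big(T^{-1}(y)\big)\left|\frac{d\,T^{-1}(y)}{dy}\right|, \qquad \int p_Y(y)\,dy = \int p_X(x)\,dx = 1
```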

fritzo commented 4 years ago

@stefanwebb do you have any opinion on the RV interface? I know you've spent a lot of time thinking about the language of normalizing flows.

stefanwebb commented 4 years ago

I agree with your comments, Fritz; I think it would be better to have a random variable object that represents the variable rather than the distribution, and to perform operations on that.

stefanwebb commented 4 years ago

I hadn't thought about structuring things in terms of a random variable abstraction, although it seems a more natural way to write the normalizing flows library. For instance, a transform should take a random variable as input and return a random variable.

fritzo commented 4 years ago

@ecotner Maybe the safest design choice is to create a new RandomVariable class that does not inherit from TorchDistribution, i.e. is a container class rather than a mixin. I think that would give us a little more flexibility, not needing to avoid conflict with existing Distribution methods.
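A container design can be sketched roughly as follows: the random variable holds a distribution and a list of transforms instead of subclassing the distribution, so none of its operators can shadow or collide with distribution methods. This stdlib toy is an assumption-laden illustration; `RandomVariable` and `ToyExponential` here are hypothetical and not the code from PR #2448:

```python
import random
from typing import Callable, List, Optional

class ToyExponential:
    """Stand-in for a Distribution; the container only relies on .sample(rng)."""
    def __init__(self, rate: float):
        self.rate = rate

    def sample(self, rng: random.Random) -> float:
        return rng.expovariate(self.rate)

class RandomVariable:
    """Hypothetical container: wraps a distribution rather than subclassing it."""

    def __init__(self, distribution, transforms: Optional[List[Callable[[float], float]]] = None):
        self.distribution = distribution
        self.transforms = list(transforms or [])

    def transform(self, fn: Callable[[float], float]) -> "RandomVariable":
        # Returns a new object; self is never mutated.
        return RandomVariable(self.distribution, self.transforms + [fn])

    def __add__(self, c):
        return self.transform(lambda x: x + c)
    __radd__ = __add__

    def __mul__(self, c):
        return self.transform(lambda x: x * c)
    __rmul__ = __mul__

    def __neg__(self):
        return self.transform(lambda x: -x)

    def sample(self, rng: random.Random) -> float:
        x = self.distribution.sample(rng)
        for fn in self.transforms:  # apply transforms in the order they were added
            x = fn(x)
        return x

X = RandomVariable(ToyExponential(1.0))
Y = -(2 * X + 1)  # samples are always <= -1
rng = random.Random(0)
print(Y.sample(rng))
```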

eb8680 commented 4 years ago

@ecotner the + operation in Funsor usually means a sum of log-densities, i.e. a product of densities. It's useful to work with distributions rather than random variables when writing inference code because that's the typical setting for specifying inference algorithms. When every random variable in a model has a density with respect to the same base measure, the program's product measure has a joint density that can be written as a factor graph.
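A small sketch of that factor-graph view, assuming a hypothetical two-site model z ~ Normal(0, 1), x ~ Normal(z, 1): the joint log-density is the sum of one log-factor per site, and summing z out numerically recovers the known marginal x ~ Normal(0, √2). Plain Python stands in for Funsor here:

```python
import math

def normal_log_pdf(x: float, loc: float, scale: float) -> float:
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

# Hypothetical two-site model: z ~ Normal(0, 1), then x ~ Normal(z, 1).
def joint_log_density(z: float, x: float) -> float:
    # One log-factor per sample site; "+" on log-densities multiplies the densities.
    return normal_log_pdf(z, 0.0, 1.0) + normal_log_pdf(x, z, 1.0)

def marginal_density(x: float, lo: float = -10.0, hi: float = 10.0, n: int = 4000) -> float:
    """Sum z out of the factor graph with the trapezoidal rule."""
    h = (hi - lo) / n
    total = 0.0
    for i in range(n + 1):
        weight = 0.5 if i in (0, n) else 1.0
        total += weight * math.exp(joint_log_density(lo + i * h, x))
    return total * h

# Analytically x ~ Normal(0, sqrt(2)); the two numbers should agree closely.
print(marginal_density(0.7), math.exp(normal_log_pdf(0.7, 0.0, math.sqrt(2.0))))
```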

I don't think a new RandomVariable object is a great long-term solution because most operations on random variables are not invertible and don't have obvious computational interpretations other than constructing a lazy expression that can be .sample()d. You're basically going to end up duplicating a lot of the code for sampling, shape inference and inversion that's already in Funsor or in PyTorch's new LazyTensor.

eb8680 commented 4 years ago

My preference would be to follow the idea in section 3 of the Funsor paper and have pyro.sample sites (or RandomVariable objects) emit funsor.Variable objects, which transform like random variables. I suspect you'll come to similar design conclusions if you spend some time trying to give a semantics to expressions that contain multiple dependent RandomVariables.
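One reason expressions with multiple dependent random variables need careful semantics: X - X should be exactly zero, while X - Y for independent X and Y should not, so a lazy expression must share one draw per variable within each joint sample. A stdlib toy illustrating that idea (the `Expr`/`Var`/`BinOp` classes are hypothetical and are not `funsor.Variable`):

```python
import random

class Expr:
    """Hypothetical lazy expression over random variables."""
    def __add__(self, other):
        return BinOp("+", self, lift(other))

    def __radd__(self, other):
        return BinOp("+", lift(other), self)

    def __sub__(self, other):
        return BinOp("-", self, lift(other))

    def __rsub__(self, other):
        return BinOp("-", lift(other), self)

    def sample(self, rng: random.Random) -> float:
        # One joint draw: each named variable is sampled at most once per call,
        # so repeated occurrences of the same variable share a value.
        return self._eval(rng, {})

class Const(Expr):
    def __init__(self, value: float):
        self.value = value

    def _eval(self, rng, env):
        return self.value

class Var(Expr):
    def __init__(self, name: str, sampler):
        self.name, self.sampler = name, sampler

    def _eval(self, rng, env):
        if self.name not in env:
            env[self.name] = self.sampler(rng)
        return env[self.name]

class BinOp(Expr):
    def __init__(self, op: str, lhs: Expr, rhs: Expr):
        self.op, self.lhs, self.rhs = op, lhs, rhs

    def _eval(self, rng, env):
        a, b = self.lhs._eval(rng, env), self.rhs._eval(rng, env)
        return a + b if self.op == "+" else a - b

def lift(v):
    return v if isinstance(v, Expr) else Const(v)

x = Var("x", lambda r: r.gauss(0.0, 1.0))
y = Var("y", lambda r: r.gauss(0.0, 1.0))
rng = random.Random(0)
print((x - x).sample(rng))  # 0.0: both occurrences of x share one draw
print((x - y).sample(rng))  # generally nonzero: x and y are independent
```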

ecotner commented 4 years ago

@fritzo @stefanwebb @eb8680 I made a fairly simple wrapper/container around Distribution that allows you to easily apply arithmetic operations on the RandomVariable object (and other transformations parametrized by Transform). It's in PR #2448 if you want to give any critique.

fritzo commented 4 years ago

@ecotner

Most often in my experience it is more common to consider [random variable operations]

Let me expand on @eb8680's explanation of inference operations. In the Funsor paper we're using two classes of operations: "model side" operations in user-facing model code (like the random variable operations you commonly see), and "inference side" operations that are used in inference code (e.g. inside the guts of TraceEnum_ELBO). For example, in Figure 2 of the paper you can see a "model side" operation y <- exp(z) on the left, and an "inference side" operation p <- p x Gaussian(...) on the right. Our future plans in Pyro are to make good use of both types of operations through the Funsor library. The way we're planning to separate the two types of operations is to allow "model side" operations on the return values of pyro.sample statements and to allow "inference side" operations on Distribution objects. Note that whereas Pyro 1.x only returns torch.Tensor objects from pyro.sample statements, our plans for Funsor-based Pyro are to also return random-variable-like objects (@eb8680's Variable) that support exactly the "model side" operations you are suggesting.
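The distinction can be sketched in a few lines: model-side operations act on sampled values, while inference-side operations act on (log-)density functions and return new, generally unnormalized, functions. This is a stdlib illustration only; the Gaussian factor's loc=1 and scale=1 are arbitrary values standing in for the elided `Gaussian(...)`, and all names are hypothetical:

```python
import math
import random

def normal_log_pdf(z: float, loc: float = 0.0, scale: float = 1.0) -> float:
    return -0.5 * ((z - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

# "Model side": operations act on sampled values.
def model(rng: random.Random):
    z = rng.gauss(0.0, 1.0)
    y = math.exp(z)  # y <- exp(z)
    return z, y

# "Inference side": operations act on (log-)density functions.
def times_gaussian(log_p, loc: float, scale: float):
    """Return the log of the pointwise product p(z) * Normal(z; loc, scale), unnormalized."""
    return lambda z: log_p(z) + normal_log_pdf(z, loc, scale)

prior = normal_log_pdf
posterior_unnormalized = times_gaussian(prior, loc=1.0, scale=1.0)  # p <- p x Gaussian(...)

rng = random.Random(0)
print(model(rng))
# The product of N(0, 1) and N(1, 1) densities peaks at z = 0.5.
print(max([0.3, 0.4, 0.5, 0.6, 0.7], key=posterior_unnormalized))  # 0.5
```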

I hope this explains our concern with adding "model side" operations to Distribution objects. I still think it would be fine to add your operations if we clearly gate them behind a .rv <--> .dist partition as I've suggested in your PR #2448. Alternatively, we'd be happy to help spin you up on the Funsor library in case you'd like to contribute there (we see it as the future of both Pyro and our JAX-based NumPyro).
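A minimal sketch of what gating arithmetic behind a `.rv` / `.dist` partition could look like, assuming operators live only on the random-variable side and `.dist` converts back to a plain distribution. `Dist`, `TransformedDist`, and `RV` are hypothetical toys loosely modelled on the suggestion, not Pyro's API:

```python
class Dist:
    """Toy stand-in for a Distribution: deliberately defines no arithmetic operators."""
    def __init__(self, name: str):
        self.name = name

    @property
    def rv(self) -> "RV":
        # Explicit opt-in to random-variable ("model side") semantics.
        return RV(self)

class TransformedDist(Dist):
    def __init__(self, base: Dist, ops: tuple):
        super().__init__(f"{base.name} | {' '.join(ops)}")
        self.base, self.ops = base, ops

class RV:
    """Toy random variable: the only place where arithmetic operators live."""
    def __init__(self, dist: Dist, ops: tuple = ()):
        self._dist, self._ops = dist, ops

    def __add__(self, c):
        return RV(self._dist, self._ops + (f"+{c}",))

    def __mul__(self, c):
        return RV(self._dist, self._ops + (f"*{c}",))

    @property
    def dist(self) -> Dist:
        # Back to the "inference side": a plain Dist with no operators attached.
        return TransformedDist(self._dist, self._ops) if self._ops else self._dist

d = Dist("Gamma(1, 1)")
print((d.rv * 2 + 1).dist.name)  # Gamma(1, 1) | *2 +1
try:
    d + 1
except TypeError:
    print("Dist objects do not support arithmetic")
```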

ecotner commented 4 years ago

@fritzo I'm not an expert in inference algorithms, but I think I understand what you're getting at. Inference requires direct manipulation of the densities, and models (typically) require transformations of random variables, and these things should be kept distinct.

But if you already have plans for a Variable object, do you even need my RandomVariable, or would that basically be adding a second API for something that already exists?

And I'd love to contribute, but I'm not sure how; I'm more of an observer/enthusiast when it comes to this stuff; my background is in (theoretical) physics, which is notoriously devoid of statistical rigor haha. I'm always down to learn though.

fritzo commented 4 years ago

do you even need my RandomVariable, or would that basically be adding a second API ...?

I think you found a slick API that I'd be happy to have in Pyro. The Variable interface is still a ways off, and even once that exists I think your syntax would be a great way to write unit tests for the newer Variable syntax.

fritzo commented 4 years ago

Implemented in #2448