janosh / awesome-normalizing-flows

Awesome resources on normalizing flows.
MIT License

LAMPE: a PyTorch package for posterior estimation which implements normalizing flows #38

Closed: francois-rozet closed this issue 2 years ago

francois-rozet commented 2 years ago

Hello 👋,

TL;DR: The lampe package implements normalizing flows in PyTorch. I believe it is relevant to this collection. I hope you like it!

I'm a researcher interested in simulation-based inference and posterior estimation. I have written a low-level library for amortized posterior estimation called lampe. Initially, LAMPE relied on nflows for its normalizing flows, but that quickly became a limitation. I was not happy with some of nflows' design choices. For instance, it is only possible to sample or evaluate batches, and most operators do not support broadcasting. It is also not possible to use networks other than the built-in ones. I considered contributing to nflows, but the package no longer seems to be actively developed.

So I decided to implement my own normalizing flows within LAMPE. The goal was to rely as much as possible on the existing distributions and transformations of PyTorch. Unfortunately, PyTorch distributions and transforms are not modules: they don't implement a forward method, you cannot send their parameters to the GPU with .to('cuda'), and you cannot even retrieve their parameters with .parameters(). To solve this problem, LAMPE defines two (abstract) classes: DistributionModule and TransformModule. The former is any nn.Module whose forward method returns a PyTorch Distribution. Similarly, the latter is any nn.Module whose forward method returns a PyTorch Transform. Then, what is a normalizing flow? It is simply an nn.Module constructed from a base DistributionModule and a list of TransformModules.
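The DistributionModule / TransformModule idea can be sketched with nothing but `torch.nn` and `torch.distributions`. This is a minimal sketch, not LAMPE's implementation: the class names `DiagNormalModule`, `ConditionalAffineModule` and `Flow` are hypothetical, the transform is a deliberately trivial context-dependent affine map rather than a MAF or NSF layer, and the direction convention (transforms map base samples to data) follows torch's `TransformedDistribution`, which may differ from LAMPE's.

```python
import torch
import torch.nn as nn
from torch.distributions import (
    AffineTransform,
    Distribution,
    Independent,
    Normal,
    Transform,
    TransformedDistribution,
)


class DiagNormalModule(nn.Module):
    """Hypothetical DistributionModule: forward() returns a torch Distribution."""

    def __init__(self, features: int):
        super().__init__()
        # Buffers are not trainable parameters, but .to('cuda') still moves them.
        self.register_buffer("loc", torch.zeros(features))
        self.register_buffer("scale", torch.ones(features))

    def forward(self, y: torch.Tensor) -> Distribution:
        # The base ignores the context y; Independent(..., 1) makes the
        # last dimension the event dimension.
        return Independent(Normal(self.loc, self.scale), 1)


class ConditionalAffineModule(nn.Module):
    """Hypothetical TransformModule: forward() returns a torch Transform."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # Maps the context y to one shift and one log-scale per feature.
        self.net = nn.Linear(context, 2 * features)

    def forward(self, y: torch.Tensor) -> Transform:
        shift, log_scale = self.net(y).chunk(2, dim=-1)
        # event_dim=1 sums the log-Jacobian over the feature dimension.
        return AffineTransform(shift, log_scale.exp(), event_dim=1)


class Flow(nn.Module):
    """A flow: a base DistributionModule plus a list of TransformModules."""

    def __init__(self, base: nn.Module, transforms: list):
        super().__init__()
        self.base = base
        self.transforms = nn.ModuleList(transforms)

    def forward(self, y: torch.Tensor) -> Distribution:
        # Conditioning on y yields a plain torch Distribution.
        return TransformedDistribution(self.base(y), [t(y) for t in self.transforms])


flow = Flow(DiagNormalModule(7), [ConditionalAffineModule(7, 16) for _ in range(3)])
conditioned = flow(torch.randn(16))
print(conditioned.sample((5, 6)).shape)                  # torch.Size([5, 6, 7])
print(conditioned.log_prob(torch.randn(5, 6, 7)).shape)  # torch.Size([5, 6])
```

Because every piece is an `nn.Module`, `flow.parameters()` and `flow.to('cuda')` work as usual, while `flow(y)` hands back an ordinary torch `Distribution`.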

This design allows for very concise implementations of quite complex flows. Currently, LAMPE implements the masked autoregressive flow (MAF), the neural spline flow (NSF), the neural autoregressive flow (NAF), and a NAF variant based on unconstrained monotonic neural networks (UMNN). All these flows support coupling (2 passes for the inverse), fully autoregressive (as many passes as features), or anything in between (see Graphical Normalizing Flows). All of that fits in about 800 lines of code, including whitespace and documentation. If you are interested, take a look at the transformations and flows.
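The range from "coupling" to "fully autoregressive" can be made concrete with a small sketch. The grouping scheme below is an assumption for illustration, not LAMPE's code: features are split into `passes` contiguous groups, and the parameters for a feature may only depend on features in earlier groups. Group k can only be inverted once groups 0 to k - 1 are known, so inverting takes one sequential sweep per group. The helper names `pass_ranks`, `adjacency` and `inverse_sweeps` are hypothetical.

```python
from typing import List


def pass_ranks(features: int, passes: int) -> List[int]:
    """Assign each feature to one of `passes` contiguous groups (hypothetical scheme)."""
    return [i * passes // features for i in range(features)]


def adjacency(features: int, passes: int) -> List[List[bool]]:
    """adj[i][j] is True when the parameters for feature i may depend on feature j."""
    ranks = pass_ranks(features, passes)
    # Strictly earlier groups only, so no feature depends on itself.
    return [[ranks[j] < ranks[i] for j in range(features)] for i in range(features)]


def inverse_sweeps(features: int, passes: int) -> int:
    """Sequential network evaluations needed to invert: one per group."""
    return len(set(pass_ranks(features, passes)))


print(pass_ranks(7, 2), inverse_sweeps(7, 2))  # coupling: [0, 0, 0, 0, 1, 1, 1] 2
print(pass_ranks(7, 7), inverse_sweeps(7, 7))  # fully autoregressive: [0, 1, 2, 3, 4, 5, 6] 7
```

Since no feature depends on itself or on a later group, the Jacobian of such a transform is triangular, and its log-determinant reduces to a sum of per-feature terms.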

Here is a small example with a neural spline flow (NSF).

```python
>>> import torch
>>> import lampe
>>> flow = lampe.nn.flows.NSF(7, context=16, transforms=3, hidden_features=[64] * 3, activation='ELU')
>>> flow
NSF(
  (transforms): ModuleList(
    (0): SoftclipTransform(bound=5.0)
    (1): MaskedAutoregressiveTransform(
      (base): MonotonicRQSTransform(bins=8)
      (order): [0, 1, 2, 3, 4, 5, 6]
      (params): MaskedMLP(
        (0): MaskedLinear(in_features=23, out_features=64, bias=True)
        (1): ELU(alpha=1.0)
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ELU(alpha=1.0)
        (4): MaskedLinear(in_features=64, out_features=64, bias=True)
        (5): ELU(alpha=1.0)
        (6): MaskedLinear(in_features=64, out_features=161, bias=True)
      )
    )
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicRQSTransform(bins=8)
      (order): [6, 5, 4, 3, 2, 1, 0]
      (params): MaskedMLP(
        (0): MaskedLinear(in_features=23, out_features=64, bias=True)
        (1): ELU(alpha=1.0)
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ELU(alpha=1.0)
        (4): MaskedLinear(in_features=64, out_features=64, bias=True)
        (5): ELU(alpha=1.0)
        (6): MaskedLinear(in_features=64, out_features=161, bias=True)
      )
    )
    (3): MaskedAutoregressiveTransform(
      (base): MonotonicRQSTransform(bins=8)
      (order): [0, 1, 2, 3, 4, 5, 6]
      (params): MaskedMLP(
        (0): MaskedLinear(in_features=23, out_features=64, bias=True)
        (1): ELU(alpha=1.0)
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ELU(alpha=1.0)
        (4): MaskedLinear(in_features=64, out_features=64, bias=True)
        (5): ELU(alpha=1.0)
        (6): MaskedLinear(in_features=64, out_features=161, bias=True)
      )
    )
    (4): Inverse(SoftclipTransform(bound=5.0))
  )
  (base): DiagNormal(loc: torch.Size([7]), scale: torch.Size([7]))
)
```
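The layer sizes in the printed flow can be reproduced by hand. This is a sketch inferred from the printed sizes, not from LAMPE's source: it assumes each monotonic rational-quadratic spline with K bins is parameterized by K bin widths, K bin heights and K - 1 interior knot derivatives (the neural spline flow formulation with fixed boundary derivatives). The helper `rqs_params_per_feature` is hypothetical.

```python
def rqs_params_per_feature(bins: int) -> int:
    """Spline parameters per feature, assuming K widths, K heights and K - 1 interior derivatives."""
    return bins + bins + (bins - 1)


features, context, bins = 7, 16, 8  # from the NSF call above; bins=8 from the repr

# The masked MLP reads the features x and the context y concatenated together.
in_features = features + context  # 7 + 16 = 23

# It outputs one full set of spline parameters per feature.
out_features = features * rqs_params_per_feature(bins)  # 7 * 23 = 161

print(in_features, out_features)  # 23 161
```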

At this point, the flow is just an nn.Module. To condition it on a context y, we call it with y. This returns a distribution that can be evaluated (log_prob) or sampled (sample) just like any torch distribution.

```python
>>> y = torch.randn(16)
>>> conditioned = flow(y)
>>> conditioned.sample()
tensor([ 1.1381,  0.3619, -1.9963,  0.2681, -0.1613,  0.1885, -0.4108])
>>> conditioned.sample((5, 6)).shape
torch.Size([5, 6, 7])
>>> x = torch.randn(7)
>>> conditioned.log_prob(x)
tensor(-8.6289, grad_fn=<AddBackward0>)
>>> x = torch.randn(5, 6, 7)
>>> conditioned.log_prob(x).shape
torch.Size([5, 6])
```
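The shapes in this session follow the usual `torch.distributions` convention: a sample has shape `sample_shape + batch_shape + event_shape`, and `log_prob` sums out the event dimension while broadcasting over the rest. The sketch below reproduces this with a plain diagonal Gaussian standing in for the conditioned flow; it uses only torch and does not call LAMPE.

```python
import torch
from torch.distributions import Independent, Normal

# A 7-dimensional diagonal Gaussian standing in for `conditioned`.
# Independent(..., 1) reinterprets the last dimension as the event dimension.
single = Independent(Normal(torch.zeros(7), torch.ones(7)), 1)
x = single.sample((5, 6))   # sample_shape + event_shape -> [5, 6, 7]
log_p = single.log_prob(x)  # event dimension summed out -> [5, 6]

# Ten parameter sets (e.g. one per context) add a batch dimension.
batched = Independent(Normal(torch.zeros(10, 7), torch.ones(10, 7)), 1)
xb = batched.sample((5,))                  # sample + batch + event -> [5, 10, 7]
log_pb = batched.log_prob(torch.randn(7))  # one x broadcasts over the batch -> [10]
```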
janosh commented 2 years ago

@francois-rozet Thanks for bringing this up and explaining the design decisions! Looks like a very nice package. 👍 Happy to take a PR that adds lampe to data/packages.yml.

francois-rozet commented 2 years ago

Thanks! I'll do it right away!

francois-rozet commented 2 years ago

Done. By the way, I see that Unconstrained Monotonic Neural Networks and Graphical Normalizing Flows are not referenced in the publication list. They are rather interesting!

janosh commented 2 years ago

Adding those would be great! And https://github.com/AWehenkel/UMNN under data/code.yml too. Would you like to submit another PR?

francois-rozet commented 2 years ago

@janosh I have submitted a PR with both papers.

francois-rozet commented 1 year ago

Hello @janosh 👋 The normalizing flow implementations within the lampe package have been extracted into a standalone package called Zuko. I think the entry for lampe should be replaced by zuko. Should I send a PR to update this?

janosh commented 1 year ago

> I think the entry for lampe should be replaced by zuko. Should I send a PR to update this?

Yes, sounds good! Feel free to mention and link lampe in the zuko description as a downstream tool and what it's used for.