didriknielsen / survae_flows

Code for paper "SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows"
MIT License

CIF flow implementation #14

Closed SamGalanakis closed 3 years ago

SamGalanakis commented 3 years ago

Hello,

Based on the paper, CIF blocks can be implemented as Augment -> Bijection -> Slice. In this case, is the bijection done in the augmented (higher) dimensionality (the concatenated input + noise), or in the input dimensionality conditioned on the noise? From the CIF paper it seems to be the latter, and I can't see a straightforward way of implementing that with the current code. Hope that wasn't too unclear!

didriknielsen commented 3 years ago

Hi, thanks for your question!

I believe it can be implemented with something like:

class CIFLayer(SequentialTransform):
    def __init__(self, dim, aug_dim, net_q, net_f, net_p):
        super().__init__([
            Augment(ConditionalNormal(net_q), x_size=dim),
            Reverse(dim+aug_dim),
            AffineCouplingBijection(net_f, num_condition=aug_dim),
            Reverse(dim+aug_dim),
            Slice(ConditionalNormal(net_p), num_keep=dim),
        ])

Note that the Reverse layers are needed because Augment appends the noise on the right, while AffineCouplingBijection conditions on the left.
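To make that ordering concrete, here is a shape-level sketch in plain PyTorch (not the survae API) of why the reversal is needed:

```python
import torch

# Toy sizes: data dim 4, augmentation dim 2.
dim, aug_dim = 4, 2
x = torch.zeros(1, dim)         # stands in for the data part
noise = torch.ones(1, aug_dim)  # stands in for the augmented noise

# Augment appends the noise on the right: [x | noise]
z = torch.cat([x, noise], dim=1)

# AffineCouplingBijection(..., num_condition=aug_dim) conditions on the
# FIRST aug_dim entries, so the noise must sit on the left: [noise | x]
z_rev = z.flip(dims=[1])

# After the flip, the conditioning slot holds the noise...
assert torch.equal(z_rev[:, :aug_dim], noise)
# ...and the transformed slot holds the data part.
assert torch.equal(z_rev[:, aug_dim:], x)
```

The second Reverse then restores the original ordering before Slice drops the noise part.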

Here is a fully runnable example:

from survae.flows import Flow
from survae.distributions import StandardNormal, ConditionalNormal
from survae.transforms import SequentialTransform, Augment, Slice, Reverse, AffineCouplingBijection
from survae.nn.nets import MLP

class CIFLayer(SequentialTransform):
    def __init__(self, dim, aug_dim, net_q, net_f, net_p):
        super().__init__([
            Augment(ConditionalNormal(net_q), x_size=dim),
            Reverse(dim+aug_dim),
            AffineCouplingBijection(net_f, num_condition=aug_dim),
            Reverse(dim+aug_dim),
            Slice(ConditionalNormal(net_p), num_keep=dim),
        ])

dim = 24
aug_dim = 6

model = Flow(base_dist=StandardNormal((dim,)),
             transforms=CIFLayer(
                dim=dim, aug_dim=aug_dim,
                net_q=MLP(dim,2*aug_dim,[200,100]),
                net_f=MLP(aug_dim,2*dim,[200,100], out_lambda=lambda x: x.reshape(-1,dim,2)),
                net_p=MLP(dim,2*aug_dim,[200,100]),
             ))

x = model.sample(5)
print(x.shape, x.min(), x.max())
print(model.log_prob(x))
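For intuition on the question above: the coupling step transforms only the dim-sized data part, conditioned on the aug_dim noise entries, i.e. a bijection in the input dimensionality conditioned on the noise. A minimal plain-PyTorch sketch (hypothetical helper, not the survae API):

```python
import torch

def affine_coupling(x, u, net):
    # net maps the conditioner u to a shift and log-scale for x;
    # only x is transformed, u passes through unchanged.
    shift, log_scale = net(u).chunk(2, dim=1)
    z = x * log_scale.exp() + shift
    ldj = log_scale.sum(dim=1)  # log|det| of the transform acting on x only
    return z, ldj

dim, aug_dim = 4, 2
net = torch.nn.Linear(aug_dim, 2 * dim)
x = torch.randn(3, dim)
u = torch.randn(3, aug_dim)
z, ldj = affine_coupling(x, u, net)
assert z.shape == (3, dim) and ldj.shape == (3,)
```

Since the Jacobian only involves the transformed dim entries, the bijection is effectively in the input dimensionality even though the flow operates on the concatenated tensor.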

Let me know if you have any issues, happy to help!

SamGalanakis commented 3 years ago

Thanks for the detailed reply, it was exactly what I needed. It may be worth adding it to the examples too.