didriknielsen / survae_flows

Code for paper "SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows"
MIT License

Does the vanilla Conv1x1 layer have a guarantee of invertibility? #16

Open pkulwj1994 opened 3 years ago

pkulwj1994 commented 3 years ago

I am reading and reimplementing your great codebase, Didrik. This repo is such nice work on normalizing-flow generative experiments!

In the vanilla implementation of the Conv1x1 bijective transform, I find that the conv1x1 kernel (weight) is not guaranteed to be strictly invertible. This can also cause torch.slogdet() to diverge (it falls back on an SVD for a non-invertible kernel matrix) and raise an error. The same problem arises in the Glow paper, but the authors do not give a way to handle a non-invertible kernel. My question is: do you have any refinements of the conv1x1 kernel, analogous to the refinements that force planar flows and Sylvester flows to be invertible at all times? I think using the LU or Householder tricks on the kernel would restrict the expressiveness of the Conv1x1 transform, which makes them unappealing.
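As a small illustration of the failure mode (a toy example of my own, not taken from the repo): torch.slogdet on a singular matrix returns a log-determinant of -inf, which then poisons the training objective even when no exception is raised.

import torch

# Toy example (my own): a rank-deficient kernel yields sign 0 and log|det| = -inf.
w = torch.eye(4)
w[0] = w[1]                              # duplicate a row -> det(w) = 0
sign, logabsdet = torch.slogdet(w)
print(sign.item(), logabsdet.item())     # 0.0 -inf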

The code for the Conv1x1 layer is copied below:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from operator import mul

from survae.transforms.bijections import Bijection  # base bijection class from this repo


class Conv1x1(Bijection):

    def __init__(self, num_channels, orthogonal_init=True, slogdet_cpu=True):
        super(Conv1x1, self).__init__()

        self.num_channels = num_channels
        self.slogdet_cpu = slogdet_cpu
        # The 1x1-conv kernel is an unconstrained num_channels x num_channels matrix.
        self.weight = nn.Parameter(torch.Tensor(num_channels, num_channels))
        self.reset_parameters(orthogonal_init)

    def reset_parameters(self, orthogonal_init):
        self.orthogonal_init = orthogonal_init

        if self.orthogonal_init:
            # Orthogonal init gives |det W| = 1, i.e. zero log-det at initialization.
            nn.init.orthogonal_(self.weight)
        else:
            bound = 1.0 / np.sqrt(self.num_channels)
            nn.init.uniform_(self.weight, -bound, bound)

    def _conv(self, weight, v):
        # Apply the channel-mixing matrix as a 1x1 convolution over 1, 2 or 3 spatial dims.
        _, channel, *features = v.shape
        n_feature_dims = len(features)

        fill = (1,) * n_feature_dims
        weight = weight.view(channel, channel, *fill)

        if n_feature_dims == 1:
            return F.conv1d(v, weight)
        elif n_feature_dims == 2:
            return F.conv2d(v, weight)
        elif n_feature_dims == 3:
            return F.conv3d(v, weight)
        else:
            raise ValueError(f'Got {n_feature_dims}d tensor, expected 1d, 2d, or 3d')

    def _logdet(self, x_shape):
        b, c, *dims = x_shape
        # Log-det of the Jacobian: slogdet(W) times the number of spatial positions.
        if self.slogdet_cpu:
            _, ldj_per_pixel = torch.slogdet(self.weight.to('cpu'))
        else:
            _, ldj_per_pixel = torch.slogdet(self.weight)
        ldj = ldj_per_pixel * reduce(mul, dims)
        return ldj.expand([b]).to(self.weight.device)

    def forward(self, x):
        z = self._conv(self.weight, x)
        ldj = self._logdet(x.shape)

        return z, ldj

    def inverse(self, z):
        # Requires W to be invertible; torch.inverse fails for a singular weight.
        weight_inv = torch.inverse(self.weight)
        x = self._conv(weight_inv, z)
        return x
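For reference, here is a quick round-trip check (my own snippet, assuming the class above is importable as written) confirming that forward and inverse cancel and that the returned log-det matches slogdet of the kernel times the number of spatial positions:

import torch

# Sanity check (my own snippet): round-trip x -> z -> x and compare the reported
# log-det against slogdet(weight) * (number of spatial positions) for an 8x8 input.
layer = Conv1x1(num_channels=3)
x = torch.randn(2, 3, 8, 8)

z, ldj = layer.forward(x)
x_rec = layer.inverse(z)

print(torch.allclose(x, x_rec, atol=1e-5))        # True, as long as the weight stays invertible
expected = torch.slogdet(layer.weight)[1] * 8 * 8
print(torch.allclose(ldj, expected.expand(2)))    # True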

Many thanks, Didrik!

didriknielsen commented 3 years ago

Hi again,

In theory, the layer might become non-invertible, but this is very unlikely and won't really happen in practice. If you get an error, I would suspect it is due to NaNs appearing somewhere else, which tends to crash the slogdet function. One possibility is of course to parameterize the matrix differently to ensure that it is truly impossible for it to become non-invertible.
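As a concrete example of that last option (a sketch of my own, not code from this repo), one could use the LU parameterization proposed in the Glow paper: the diagonal of U is stored as sign_s * exp(log_s), so it can never reach zero and the determinant stays finite and non-zero by construction.

import numpy as np
import scipy.linalg
import torch
import torch.nn as nn

# Sketch (not from this repo): Glow-style LU parameterization of the 1x1-conv
# weight. W = P @ L @ U with unit-diagonal L and diag(U) = sign_s * exp(log_s),
# so W is invertible by construction and log|det W| = sum(log_s) is always finite.
class LUConv1x1Weight(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        # Start from a random orthogonal matrix and take its LU decomposition.
        w_init = np.linalg.qr(np.random.randn(num_channels, num_channels))[0]
        p, l, u = scipy.linalg.lu(w_init)
        s = np.diag(u)

        self.register_buffer('p', torch.tensor(p, dtype=torch.float32))        # fixed permutation
        self.register_buffer('sign_s', torch.tensor(np.sign(s), dtype=torch.float32))
        self.lower = nn.Parameter(torch.tensor(np.tril(l, -1), dtype=torch.float32))
        self.upper = nn.Parameter(torch.tensor(np.triu(u, 1), dtype=torch.float32))
        self.log_s = nn.Parameter(torch.tensor(np.log(np.abs(s)), dtype=torch.float32))

        self.register_buffer('eye', torch.eye(num_channels))
        self.register_buffer('l_mask', torch.tril(torch.ones(num_channels, num_channels), -1))

    def weight(self):
        l = self.lower * self.l_mask + self.eye
        u = self.upper * self.l_mask.t() + torch.diag(self.sign_s * torch.exp(self.log_s))
        return self.p @ l @ u

    def logdet_per_pixel(self):
        # Per-pixel log|det W|; multiply by the number of spatial positions as in _logdet above.
        return self.log_s.sum()

The trade-off is that the permutation and the diagonal signs are fixed after initialization, so this is slightly less flexible than the unconstrained kernel.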

Hope that helps, and thanks for your question!