acids-ircam / pytorch_flows

Implementation and tutorials of normalizing flows with the novel distributions module

`flow.final_density.log_prob` strange behaviour #2

Open · colobas opened this issue 5 years ago

colobas commented 5 years ago

Hi @esling ,

First of all, thanks for sharing these notebooks, they're great.

I'm playing around with the classes you defined, and trying to understand why you're calculating densities manually (explicitly through the change-of-variables formula) rather than using `flow.final_density.log_prob`.

I thought these should yield the same result, but I tried it and they didn't:

Below are the definitions of the classes I used. They're mostly equal to yours, but I had to make a small change to the signature of the `log_abs_det_jacobian` method, because according to [the docs](https://pytorch.org/docs/stable/distributions.html?highlight=distributions%20transforms#torch.distributions.transforms.Transform.log_abs_det_jacobian), PyTorch expects that method to take both the inputs and the outputs. I think the reason this isn't working is that I might have misunderstood this method (see the small sanity check after the class definitions below).

Click to expand class definitions.

```python
# Imports for the names used below (torch, nn, distrib, transform),
# added here so the snippet is self-contained
import torch
import torch.nn as nn
import torch.distributions as distrib
import torch.distributions.transforms as transform


class Flow(transform.Transform, nn.Module):
    """
    The purpose of this class is to make `transform.Transform` 'trainable';
    simple flows will inherit from it.
    """
    def __init__(self):
        transform.Transform.__init__(self)
        nn.Module.__init__(self)
        #self.bijective = True

    # Init all parameters
    def init_parameters(self):
        for param in self.parameters():
            param.data.uniform_(-0.01, 0.01)

    # Hacky hash bypass
    def __hash__(self):
        return nn.Module.__hash__(self)


# Flow version of Leaky ReLU
class PReLUFlow(Flow):
    def __init__(self, dim):
        super(PReLUFlow, self).__init__()
        self.alpha = nn.Parameter(torch.Tensor([1]))
        self.bijective = True

    def init_parameters(self):
        for param in self.parameters():
            param.data.uniform_(0.01, 0.99)

    def _call(self, z):
        return torch.where(z >= 0, z, torch.abs(self.alpha) * z)

    def _inverse(self, z):
        return torch.where(z >= 0, z, torch.abs(1. / self.alpha) * z)

    def log_abs_det_jacobian(self, z, y):
        """
        I had to add a dummy "y" var to the method signature because PyTorch expects it, as per:
        https://pytorch.org/docs/stable/distributions.html?highlight=distributions%20transforms#torch.distributions.transforms.Transform.log_abs_det_jacobian
        """
        I = torch.ones_like(z)
        J = torch.where(z >= 0, I, self.alpha * I)
        log_abs_det = torch.log(torch.abs(J) + 1e-5)
        return torch.sum(log_abs_det, dim=1)


# Main class for normalizing flow
class NormalizingFlow(nn.Module):
    def __init__(self, dim, blocks, flow_length, density):
        super().__init__()
        biject = []
        for f in range(flow_length):
            for b_flow in blocks:
                biject.append(b_flow(dim))
        self.transforms = transform.ComposeTransform(biject)
        self.bijectors = nn.ModuleList(biject)
        self.base_density = density
        self.final_density = distrib.TransformedDistribution(density, self.transforms)
        self.log_det = []

    def forward(self, z):
        self.log_det = []
        # Applies series of flows
        for b in range(len(self.bijectors)):
            y = self.bijectors[b](z)
            self.log_det.append(self.bijectors[b].log_abs_det_jacobian(z, y))
            z = y
        return z, self.log_det
```
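To double-check my reading of that signature, I also tried it on a built-in transform. This little snippet is just my own sanity check, not something from your notebooks:

```python
import torch
from torch.distributions import transforms as transform

# For ExpTransform, y = exp(x), so log|det dy/dx| = x.
# The method takes both the input x and the output y,
# even though this particular transform only needs x.
t = transform.ExpTransform()
x = torch.randn(3)
y = t(x)
print(t.log_abs_det_jacobian(x, y))  # prints x itself
```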

I then tried plotting the density of a simple `NormalizingFlow`, first using your approach from the notebooks and then using `flow.final_density.log_prob`. I think these should yield the same result, but they clearly don't:

Click to expand the definition of the `nflow_change_density` function. It's basically the same as your `change_density` function, but made to work on `NormalizingFlow` instances rather than `Flow` instances.

```python
def nflow_change_density(flow, z):
    """
    Changed this function to work on a `NormalizingFlow` instance
    rather than a `Flow` instance.
    """
    # Apply our transform on coordinates
    f_z, log_det = flow(torch.Tensor(z))
    f_z = f_z.detach()
    log_det = log_det[0].detach()
    q0_density = flow.base_density.log_prob(torch.Tensor(z)).detach().exp()
    # Obtain our density
    q1_density = q0_density.squeeze() / np.exp(log_det.squeeze())
    return q1_density, f_z
```
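For reference, the identity I expect both computations to implement is the change-of-variables formula (writing $y = f(z)$ for the flow output):

$$\log q_1(y) = \log q_0(z) - \log\left|\det \frac{\partial f}{\partial z}\right|$$

which is exactly what `nflow_change_density` does, just in probability space ($q_1 = q_0 / |\det J|$).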

First using your approach:

```python
nflow = NormalizingFlow(
    dim=2,
    blocks=[PReLUFlow],
    flow_length=1,
    density=distrib.MultivariateNormal(torch.zeros(2), torch.eye(2))
)

nflow.bijectors[0].alpha.data = torch.Tensor([0.6])

q0_density = nflow.base_density.log_prob(torch.Tensor(z)).exp().detach()
q1_density, f_z = nflow_change_density(nflow, z)

>>> q1_density
tensor([4.9750e-08, 5.1368e-08, 5.3034e-08,  ..., 1.9093e-08, 1.8493e-08,
        1.7910e-08])

# Plot this
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(15, 5))
ax1.hexbin(z[:,0], z[:,1], C=q0_density.numpy().squeeze(), cmap='rainbow')
ax1.set_title('$q_0 = \mathcal{N}(\mathbf{0},\mathbb{I})$', fontsize=18);
ax2.hexbin(f_z[:,0], f_z[:,1], C=q1_density.numpy().squeeze(), cmap='rainbow')
ax2.set_title('$q_1=prelu(q_0)$', fontsize=18);
```

*(image: side-by-side hexbin plots of the base density $q_0$ and the transformed density $q_1$ obtained with the manual change-of-variables approach)*

And now using `nflow.final_density.log_prob`:

```python
f_z, _ = nflow(torch.Tensor(z))
f_z = f_z.detach()

q1_density = nflow.final_density.log_prob(f_z).detach()

>>> q1_density
tensor([510744.4062, 510744.4375, 510744.4688,  ..., 510744.4688,
        510744.4375, 510744.4062])
```

These values are absurdly high compared to the ones obtained via your approach, and they are log_probs; I haven't even called .exp() on them yet (and I won't, because it would blow up). Let me plot them instead (the Gaussian is also shown in log_probs here):

```python
# Plot this
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(15, 5))
ax1.hexbin(z[:,0], z[:,1], C=q0_density.log().numpy().squeeze(), cmap='rainbow')
ax1.set_title('$q_0 = \mathcal{N}(\mathbf{0},\mathbb{I})$', fontsize=18);
ax2.hexbin(f_z[:,0], f_z[:,1], C=q1_density.numpy().squeeze(), cmap='rainbow')
ax2.set_title('$q_1=prelu(q_0)$', fontsize=18);
```

*(image: hexbin plots of $\log q_0$ and of the log-probs returned by `nflow.final_density.log_prob`)*

I then plotted the log_probs computed with your approach as well, and got this:

*(image: hexbin plot of the log-probs computed with the manual change-of-variables approach)*

The shape is really similar to the one above, although the values are orders of magnitude apart. This gives me the feeling that some constant is not being added to (or subtracted from) the log_probs returned by `nflow.final_density.log_prob`.
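To make that hypothesis concrete, this is the quick check I would run (my own diagnostic, reusing the objects defined above; the variable names here are mine). If the two log-densities differed only by an additive constant, the difference below would be flat across all points:

```python
# Recompute both log-densities side by side (names are mine).
manual_q1, f_z = nflow_change_density(nflow, z)            # density from the manual approach
manual_log_q1 = manual_q1.log()
final_log_q1 = nflow.final_density.log_prob(f_z).detach()  # log-prob from TransformedDistribution

diff = final_log_q1 - manual_log_q1
print(diff.min(), diff.max(), diff.std())  # a flat diff would mean "just a constant offset"
```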

What am I missing here? I'd really appreciate it if you could give this some thought.

Thanks in advance!

colobas commented 5 years ago

So I found where the mismatch originates. It has to do with how terms get summed in PyTorch's `TransformedDistribution.log_prob`: https://pytorch.org/docs/stable/_modules/torch/distributions/transformed_distribution.html#TransformedDistribution.log_prob

More specifically, it has to do with the `_sum_rightmost` function used there, whose behaviour depends on each individual transform's `.event_dim` as well as on the final distribution's event dimension.
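For anyone following along, here is roughly how I read the relevant logic, paraphrased from the linked source for the single-transform case (the `sum_rightmost` helper below just mirrors what `torch.distributions.utils._sum_rightmost` does; `transformed_log_prob` is my own sketch, not the actual PyTorch function):

```python
import torch

def sum_rightmost(value, dim):
    # Sum out the `dim` rightmost dimensions of `value`
    # (same behaviour as torch.distributions.utils._sum_rightmost).
    if dim == 0:
        return value
    return value.reshape(value.shape[:-dim] + (-1,)).sum(-1)

def transformed_log_prob(base, t, y):
    # Paraphrase of TransformedDistribution.log_prob for one transform `t`.
    # event_dim here plays the role of len(self.event_shape) of the
    # transformed distribution (1 in our case).
    event_dim = max(len(base.event_shape), t.event_dim)
    x = t.inv(y)
    ladj = t.log_abs_det_jacobian(x, y)
    log_prob = -sum_rightmost(ladj, event_dim - t.event_dim)
    log_prob = log_prob + sum_rightmost(base.log_prob(x), event_dim - len(base.event_shape))
    return log_prob
```

With `t.event_dim == 0`, the first `sum_rightmost` call sums our already-reduced `[N]`-shaped log-det over its last dimension, which here is the batch dimension; that is where the huge, nearly constant values above come from. With `t.event_dim == 1`, `event_dim - t.event_dim == 0` and nothing extra gets summed.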

By somewhat blindly changing `event_dim = len(self.event_shape)` to `event_dim = 0`, I got it to yield the same result as your approach, so I knew I had spotted where the error was. But `event_dim = len(self.event_shape) = 1` looked correct, so I looked at the transform's `event_dim` and noticed it was 0. After setting that to 1 on instantiation, it worked.

So now I'm wondering: what's the best way to do this? How should we infer each transform's appropriate `event_dim` during the instantiation of a `NormalizingFlow`? If I understood the docs correctly, it should be 0 when operating on univariate distributions, 1 when operating on distributions over vectors, and 2 when operating on distributions over matrices. In that case, I think `event_dim=1` is a good default.
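A minimal sketch of what I have in mind (just a suggestion, and it assumes `event_dim` can still be set as a plain attribute in this PyTorch version; in later releases it became a read-only property derived from the transform's domain):

```python
import torch
import torch.nn as nn
import torch.distributions.transforms as transform

class Flow(transform.Transform, nn.Module):
    """Trainable Transform base class, now declaring its event_dim up front."""
    def __init__(self, event_dim=1):
        transform.Transform.__init__(self)
        nn.Module.__init__(self)
        # Tell TransformedDistribution that the value returned by
        # log_abs_det_jacobian is already reduced over the feature dimension.
        self.event_dim = event_dim

    def init_parameters(self):
        for param in self.parameters():
            param.data.uniform_(-0.01, 0.01)

    def __hash__(self):
        return nn.Module.__hash__(self)
```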