swiseman / bethe-min

7 stars 2 forks source link

How to use `exact_marginals` function? #1

Open TFbruv opened 4 years ago

TFbruv commented 4 years ago

Hi, can you please tell me how to use the `exact_marginals` function to generate the node and pairwise marginals of a 15x15 Ising model, as described in the paper?

yoonkim commented 4 years ago

Hi, I've uploaded code for the Ising experiments. Hope this helps!

TFbruv commented 4 years ago

The marginals given by ising.py are different from the marginals calculated by brute-force enumeration. They are expected to be equal, right? I have checked this for n = 3, 4, and 5.

yoonkim commented 4 years ago

Right, would you mind posting the code that tests this?

yoonkim commented 4 years ago

Here's a small script I just wrote to test the unary marginals, and they seem to match. Hope this helps!

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import math
import ising as ising_models

def logsumexp(x, dim=1):
    """Numerically stable ``log(sum(exp(x)))`` along dimension `dim`.

    Delegates to ``torch.logsumexp``. The previous hand-rolled version
    subtracted the slice maximum directly, which produced ``nan`` whenever
    that maximum was infinite (e.g. a slice of all ``-inf`` gives
    ``-inf - (-inf) = nan``). ``torch.logsumexp`` returns ``-inf`` for such a
    slice instead.

    Args:
        x: input tensor.
        dim: dimension to reduce (default 1). It is removed from the output,
            so a 1-D input reduced over ``dim=0`` yields a 0-dim tensor.

    Returns:
        Tensor of log-sum-exp values with `dim` reduced.
    """
    return torch.logsumexp(x, dim)

def log_energy(x, unary, binary):
    """Return the unnormalized log-score of every configuration row in `x`.

    Each row ``s`` is scored as
    ``sum_i unary[i] * s[i] + sum_i s[i] * (s @ binary)[i]``,
    i.e. a linear field term plus the quadratic form ``s @ binary @ s``.

    Args:
        x: 2-D tensor holding one configuration per row (b x n**2 in this
            script).
        unary: per-site field, broadcast against every row of `x`.
        binary: square coupling matrix; this script passes
            ``ising.binary * ising.mask``.

    Returns:
        Tensor with one score per row of `x` (dimension 1 summed out).
    """
    field_term = x * unary.unsqueeze(0)
    coupling_term = (x @ binary) * x
    per_site = field_term + coupling_term  # b x n**2
    return torch.sum(per_site, dim=1)

def _enumerate_spins(num_sites, like):
    """Return every {-1, +1} assignment of `num_sites` spins, one per row.

    Row k holds the binary digits of k (most-significant first, 0 -> -1,
    1 -> +1). This reproduces the row order of
    ``itertools.product([-1., 1.], repeat=num_sites)`` without materializing
    2**num_sites Python lists. The result is converted with ``type_as(like)``.
    """
    codes = np.arange(2 ** num_sites, dtype=np.int64)[:, None]
    shifts = np.arange(num_sites - 1, -1, -1, dtype=np.int64)
    bits = (codes >> shifts) & 1  # (2**num_sites, num_sites) array of 0/1
    return torch.from_numpy(2.0 * bits - 1.0).type_as(like)


def test():
    """Compare variable-elimination results with brute-force enumeration.

    For each grid size n in (3, 4, 5), this builds ``ising_models.Ising(n)``
    and enumerates all 2**(n**2) spin configurations in {-1, +1}. From the
    enumeration it computes the log partition function and the per-site
    probability of spin +1. It prints these next to the values from
    ``ising.log_partition_ve()`` and ``ising.marginals()``, then prints the
    largest absolute difference between the two sets of unary marginals.

    NOTE: the enumeration tensor has 2**(n**2) * n**2 elements
    (about 8.4e8 for n = 5), so the n = 5 case needs several GB of memory.
    """
    torch.manual_seed(3435)
    for n in [3, 4, 5]:
        print("")
        print("testing for grid size: %d" % n)
        print("")
        ising = ising_models.Ising(n)
        num_sites = n ** 2

        # exact quantities from variable elimination
        log_Z_ve = ising.log_partition_ve()
        # NOTE(review): the pairwise marginals (second return value) are not
        # checked here. Comparing them requires their layout, which is defined
        # in ising.py and is not visible in this script.
        unary_marginals_ve, _ = ising.marginals()

        # brute-force enumeration over every spin configuration
        x = _enumerate_spins(num_sites, ising.unary)
        x_scores = log_energy(x, ising.unary, ising.binary * ising.mask)
        log_Z_enum = logsumexp(x_scores, dim=0)
        px_enum = (x_scores - log_Z_enum).exp()  # probability of each row
        x_binary = (x + 1) * 0.5  # 1 where the spin is +1, else 0

        # log_Z_ve is printed as returned; its exact type is set by ising.py
        print("log Z (enum, ve):", log_Z_enum.item(), log_Z_ve)

        # P(x_i = +1) for every site in one vector-matrix product
        unary_enum = torch.matmul(px_enum, x_binary)
        max_err = 0.0
        for i in range(num_sites):
            p_enum = unary_enum[i].item()
            p_ve = unary_marginals_ve[i].item()
            print(i, p_enum, p_ve)
            max_err = max(max_err, abs(p_enum - p_ve))
        print("max abs unary marginal error: %g" % max_err)

if __name__ == "__main__":
    # Run the enumeration-vs-variable-elimination comparison as a script.
    test()