jjcmoon / hardness-nesy

On the Hardness of Probabilistic Neurosymbolic Learning (ICML2024)
https://arxiv.org/pdf/2406.04472
MIT License

How is WeightME implemented? #1

Closed: hehyuan closed this issue 2 months ago

hehyuan commented 4 months ago

Hi, nice work! I have read your solid paper and I am very interested in it. However, even after reading your code, I am still confused about how WeightME is implemented.

If I understand correctly, the core implementation is in solvers/cms_gen.py, in the function below:

def cnf_val(self, cnf):
    # convert to cmsgen format
    s = pycmsgen.Solver()
    s.add_clauses(cnf.clauses)
    for v in cnf.binary_vars:
        s.set_var_weight(v+1, float(cnf.weights[v, 1].exp()))

    # sample models
    models = list()
    for _ in range(self.nb_samples):
        sat, _ = s.solve(time_limit=30)
        if not sat:
            return torch.tensor(0.0)

        model = tuple(s.get_model())
        # s.add_clause([-l for l in model])
        models.append(model)

    print(f"FOUND {len(models)} / {self.nb_samples} MODELS ({len(set(models))} unique)")

    # evaluate models
    total = sum(cnf.get_lit_weights(model).sum() for model in models)
    return total

Could you explain further how the above function relates to Definition 5.1 in your paper?

Best regards!

jjcmoon commented 4 months ago

WeightME is implemented in just one line (below # evaluate models); the first part of the function samples the models. The code indeed looks a little different from the equation in Def. 5.1. The reason is that the code does not directly calculate WeightME, but produces the correct WeightME gradients when PyTorch backpropagates through it. (This is quite standard for gradient estimators; see, e.g., REINFORCE implementations.) Note that get_lit_weights returns log-probabilities, so when PyTorch backpropagates through them you get $\nabla \log w(x) = \nabla w(x) / w(x)$.
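For intuition, here is a minimal, self-contained sketch (not the repository's code; variable names are illustrative, and independent Bernoulli sampling stands in for the weighted SAT sampler) of why backpropagating through a sum of log-weights of sampled models accumulates the score-function terms that REINFORCE-style estimators rely on:

import torch

# Toy setup: three Boolean variables with learnable per-variable logits.
theta = torch.tensor([0.2, -0.5, 1.0], requires_grad=True)
probs = torch.sigmoid(theta)  # P(variable = True)

def log_weight(model, probs):
    # log w(model) = sum over variables of log P(variable takes its sampled value)
    return (model * probs.log() + (1 - model) * (1 - probs).log()).sum()

nb_samples = 4
# Sampling itself is non-differentiable, hence the detach();
# in the real code this step is done by the weighted SAT sampler (cmsgen).
models = [torch.bernoulli(probs.detach()) for _ in range(nb_samples)]

# Forward value: just a sum of log-weights of the sampled models.
total = sum(log_weight(m, probs) for m in models)
total.backward()

# theta.grad now holds  sum_i  grad_theta log w(model_i),
# i.e. the score-function term that WeightME-style estimators use.
print(theta.grad)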

Hope this makes it clearer.