braun-steven / simple-einet

An implementation of EinsumNetworks in PyTorch.

Conditional reasoning examples #8

Closed ghost closed 10 months ago

ghost commented 10 months ago

Hi Again,

The example applications seem to be geared towards either generative modelling (sampling and image completion) or discriminative learning for inference.

I was wondering if you would be happy to provide an example of generative modelling with conditional inference, similar to conditional inference in Bayesian networks. Perhaps a simple application such as Iris classification would suffice.

Thanks,

braun-steven commented 10 months ago

Hey iivevo,

would you mind giving an example? Do you mean maximizing, e.g., $p(X=x | Y=y, Z=z)$ instead of $p(X=x, Y=y, Z=z)$?

Cheers, Steven

ghost commented 10 months ago

I meant that, using a generatively trained model such as P(X, Y, Z), I would like to perform any conditional query, i.e. P(X, Y | Z).

From reading the sum-product network and probabilistic circuit literature, I would think this would require two passes, where one of the passes computes the marginal for Z.

But are you saying that you can maximize P(X|Y,Z) directly to perform MAP inference?

Also, using this setup, would it be possible to perform the query P(class | iris features) as an example of a model trained generatively to classify the iris dataset? I would assume we would need to train a model such as P(class, feature_1, ..., feature_n).

braun-steven commented 10 months ago

Got it! This is already supported.

> From reading the sum-product network and probabilistic circuit literature, I would think this would require two passes, where one of the passes computes the marginal for Z.

Exactly. PCs provide tractable marginalization, that is, we can compute any conditional $p(Q | E)$ efficiently via two forward passes, since $p(Q | E) = p(Q, E) / p(E)$.

> Also, using this setup, would it be possible to perform the query P(class | iris features) as an example of a model trained generatively to classify the iris dataset? I would assume we would need to train a model such as P(class, feature_1, ..., feature_n).

Correct! Below you can find an example of this exact scenario:

import torch
from simple_einet.layers.distributions.normal import Normal
from simple_einet.layers.distributions.categorical import Categorical
from simple_einet.einet import Einet
from simple_einet.einet import EinetConfig
from simple_einet.layers.distributions.multidistribution import MultiDistributionLayer
from sklearn.datasets import load_iris

if __name__ == "__main__":
    torch.manual_seed(0)

    # Load iris dataset
    # Random variables: sepal_length (SL), sepal_width (SW), petal_length (PL), petal_width (PW), class (C)
    features, label = load_iris(return_X_y=True)
    features = torch.tensor(features, dtype=torch.float32)
    label = torch.tensor(label, dtype=torch.int)
    data = torch.cat([features, label.unsqueeze(1).float()], dim=1).unsqueeze(1)
    data_without_class = data.clone()
    data_without_class[..., -1] = torch.nan  # Set to missing (nan)

    # Construct Einet
    cfg = EinetConfig(
        num_features=data.shape[-1],
        depth=2,
        num_sums=5,
        num_channels=1,
        num_leaves=10,
        num_repetitions=3,
        num_classes=1,
        dropout=0.0,
        leaf_type=MultiDistributionLayer,
        leaf_kwargs={
            "scopes_to_dist": [
                ([0, 1, 2, 3], Normal, {}),  # Feature Random Variables {SL, SW, PL, PW}
                ([4], Categorical, {"num_bins": 3}),  # Class Random Variable {C}
            ]
        },
    )
    einet = Einet(cfg)

    # Optimize Einet parameters (weights and leaf params)
    optim = torch.optim.Adam(einet.parameters(), lr=0.01)

    for i in range(1000):
        optim.zero_grad()

        # Compute joint log-likelihoods (full-batch): log(p(SL, SW, PL, PW, C))
        lls = einet(data)

        # Backprop negative log-likelihood loss
        nlls = -1 * lls.sum()
        nlls.backward()

        # Update weights
        optim.step()

        if i % 100 == 0:
            # Evaluate accuracy: do mpe on p(C | SL, SW, PL, PW)
            mpe_result = einet.mpe(evidence=data_without_class, marginalized_scopes=[4])
            cls = mpe_result[..., -1].squeeze().int()
            acc = (cls == label).float().mean() * 100
            print(f"Epoch: {i:5}, log-likelihood: {lls.sum().item():10.4f}, Accuracy: {acc.item():8.4f}")

    # Compute conditional log-likelihoods log(p(SL, SW | PL, PW)) = log(p(SL, SW, PL, PW)) - log(p(PL, PW))
    lls_sl_sw_pl_pw = einet(data, marginalized_scopes=[4])  # log(p(SL, SW, PL, PW)), marginalize class RV (index 4)
    lls_pl_pw = einet(data, marginalized_scopes=[0, 1, 4])  # log(p(PL, PW))
    lls_conditional = lls_sl_sw_pl_pw - lls_pl_pw  # log(p(SL, SW | PL, PW))
    print(f"Conditional log-likelihoods log(p(SL, SW | PL, PW)): {lls_conditional.sum().item():.4f}")

Output:

Epoch:     0, log-likelihood: -2287.8340, Accuracy:  33.3333
Epoch:   100, log-likelihood: -1111.5886, Accuracy:  79.3333
Epoch:   200, log-likelihood:  -736.6158, Accuracy:  96.0000
Epoch:   300, log-likelihood:  -554.5394, Accuracy:  96.6667
Epoch:   400, log-likelihood:  -385.2840, Accuracy:  96.6667
Epoch:   500, log-likelihood:  -338.9807, Accuracy:  97.3333
Epoch:   600, log-likelihood:  -276.7669, Accuracy:  96.6667
Epoch:   700, log-likelihood:  -199.5139, Accuracy:  96.6667
Epoch:   800, log-likelihood:  -283.7740, Accuracy:  96.6667
Epoch:   900, log-likelihood:  -274.0552, Accuracy:  96.6667
Conditional log-likelihoods log(p(SL, SW | PL, PW)): -127.0728

(Note that I had to add a small fix in https://github.com/braun-steven/simple-einet/commit/c59971d46667fce3c863558469b1718ef0d04850 for this to work, so keep in mind to sync the repo if you want to replicate the above.)

> But are you saying that you can maximize P(X|Y,Z) directly [...]?

Yes. Instead of maximizing $p(X, Y, Z)$, you can also use any other quantity, e.g. the conditional $p(X | Y, Z)$, to optimize the model parameters.
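
For concreteness, here is a rough sketch of what that could look like with the model from the example above. It reuses the `einet`, `data`, and `optim` objects defined there (the class RV is at index 4) and is only an illustration of the idea, not a prescribed recipe: the conditional log p(C | SL, SW, PL, PW) is obtained from two forward passes and used directly as the training objective.

# Sketch: train on the conditional log-likelihood instead of the joint.
# Assumes `einet`, `data`, and `optim` from the example above are in scope,
# with the class RV at index 4.
for i in range(1000):
    optim.zero_grad()

    # log p(SL, SW, PL, PW, C)
    lls_joint = einet(data)

    # log p(SL, SW, PL, PW), class RV marginalized out
    lls_features = einet(data, marginalized_scopes=[4])

    # log p(C | SL, SW, PL, PW) = log p(features, C) - log p(features)
    lls_conditional = lls_joint - lls_features

    loss = -lls_conditional.sum()
    loss.backward()
    optim.step()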

ghost commented 10 months ago

Hi Steven,

This is great, thanks for the code sample. I think this demonstrates the power of PCs as an alternative to graphical models.

I noticed that you used MPE to do inference rather than using the conditional log-likelihoods directly.

I am not too familiar with the term MPE. Is this a way of finding which class index produces the highest log-likelihood? If that's the case, I guess it would avoid the need to compute the log-likelihood separately for each class index. And is this always tractable?

I see now that MPE is a nice way of computing argmax_C(log(P(C=1 | features)), ..., log(P(C=n | features))).

Thanks again,

braun-steven commented 10 months ago

Correct, MPE infers the most probable state for a subset of RVs. If you model the class RV explicitly as leaves, as we did in the above example, you can infer the class variable's most probable state, given all other RVs (i.e., the feature RVs).
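
As a small illustration (reusing `einet`, `data`, `data_without_class`, and `label` from the example above; this is a sketch, not part of the original example), MPE over the class RV can be compared against explicitly enumerating the class values and taking the argmax of the joint, which equals the argmax of the conditional since p(C=c | features) is proportional to p(features, C=c). Note that MPE in general PCs is approximate (max-product), so the two predictions may not always agree exactly.

# Sketch: MPE over the class RV vs. explicit enumeration of class values.
# Assumes `einet`, `data`, `data_without_class`, and `label` from above.

# (1) MPE: fill in the most probable class state given the feature evidence.
mpe_result = einet.mpe(evidence=data_without_class, marginalized_scopes=[4])
pred_mpe = mpe_result[..., -1].squeeze().int()

# (2) Enumeration: evaluate the joint log-likelihood for each candidate class
#     and take the argmax (equivalent to the argmax over the conditional).
lls_per_class = []
for c in range(3):
    data_c = data.clone()
    data_c[..., -1] = float(c)           # set the class RV to candidate c
    lls_per_class.append(einet(data_c))  # log p(features, C=c)
pred_enum = torch.stack(lls_per_class, dim=-1).argmax(dim=-1).squeeze()

print(f"Agreement between MPE and enumeration: {(pred_mpe == pred_enum).float().mean().item():.2%}")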

Another way is to have $C$ root nodes in the PC over the full joint distribution, i.e. $p_1(X), \dots, p_C(X)$. A forward pass of the PC then gives you the implicit class-conditional likelihoods $p_1(X | C=1), \dots, p_C(X | C=C)$, which you can transform into class posteriors using Bayes' rule: $p(C=c | X) = \frac{p_c(X | C=c)\,p(C=c)}{\sum_{c'} p_{c'}(X | C=c')\,p(C=c')}$. This way you can optimize the PC to minimize the cross-entropy in a discriminative fashion. The Classifying the Iris Dataset with Einets notebook shows how this is done.
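
For reference, a minimal sketch of that second setup could look as follows (hyperparameters and tensor shapes here are illustrative assumptions; the linked notebook is the authoritative version). With `num_classes=3` the forward pass returns one log-likelihood per root, and assuming a uniform class prior, a log-softmax over the roots yields the class posteriors used for the cross-entropy loss.

import torch
import torch.nn.functional as F
from sklearn.datasets import load_iris
from simple_einet.einet import Einet, EinetConfig
from simple_einet.layers.distributions.normal import Normal

# Sketch: discriminative Einet with one root per class (illustrative settings).
features, label = load_iris(return_X_y=True)
features = torch.tensor(features, dtype=torch.float32).unsqueeze(1)  # [N, 1, 4]
label = torch.tensor(label, dtype=torch.long)

cfg = EinetConfig(
    num_features=4,
    depth=2,
    num_sums=5,
    num_channels=1,
    num_leaves=10,
    num_repetitions=3,
    num_classes=3,  # one root per class: p_1(X), p_2(X), p_3(X)
    dropout=0.0,
    leaf_type=Normal,
    leaf_kwargs={},
)
einet = Einet(cfg)
optim = torch.optim.Adam(einet.parameters(), lr=0.01)

for i in range(500):
    optim.zero_grad()
    lls = einet(features)  # [N, 3], one log-likelihood per class root
    log_posterior = torch.log_softmax(lls, dim=1)  # Bayes' rule with a uniform prior
    loss = F.nll_loss(log_posterior, label)  # cross-entropy, discriminative training
    loss.backward()
    optim.step()

acc = (einet(features).argmax(dim=1) == label).float().mean() * 100
print(f"Train accuracy: {acc.item():.2f}%")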