karlstratos / doe

Difference-of-Entropies (DoE) Estimator

Different dimensions #2

Open · link-er opened this issue 3 months ago

link-er commented 3 months ago

Hey,

I was wondering what I should change in the code in order to estimate MI between two variables of different dimensionality?

link-er commented 3 months ago

Does this code modification look correct?

import torch
import torch.nn as nn

# FF and compute_negative_ln_prob are assumed to be in scope from the
# repo's model.py, unchanged.

class DoE(nn.Module):
    """DoE estimator generalized to dim(X) != dim(Y)."""

    def __init__(self, dimX, dimY, hidden, layers, pdf):
        super(DoE, self).__init__()
        self.qY = PDF(dimY, pdf)  # marginal model q(Y)
        self.qY_X = ConditionalPDF(dimX, dimY, hidden, layers, pdf)  # conditional model q(Y|X)

    def forward(self, X, Y):
        hY = self.qY(Y)         # cross-entropy estimate of H(Y)
        hY_X = self.qY_X(Y, X)  # cross-entropy estimate of H(Y|X)

        loss = hY + hY_X
        mi_loss = hY_X - hY  # negative of the MI estimate H(Y) - H(Y|X)
        # Forward value equals mi_loss; gradients flow only through loss.
        return (mi_loss - loss).detach() + loss

class ConditionalPDF(nn.Module):
    """Factorized conditional density q(Y|X) with X-dependent parameters."""

    def __init__(self, dimX, dimY, hidden, layers, pdf):
        super(ConditionalPDF, self).__init__()
        assert pdf in {'gauss', 'logistic'}
        self.dimX = dimX
        self.dimY = dimY
        self.pdf = pdf
        # Map X to per-dimension (mu, ln_var) of Y: output size 2 * dimY.
        self.X2Y = FF(dimX, hidden, 2 * dimY, layers)

    def forward(self, Y, X):
        mu, ln_var = torch.split(self.X2Y(X), self.dimY, dim=1)
        cross_entropy = compute_negative_ln_prob(Y, mu, ln_var, self.pdf)
        return cross_entropy

class PDF(nn.Module):
    """Factorized marginal density q(Y) with learned global parameters."""

    def __init__(self, dimY, pdf):
        super(PDF, self).__init__()
        assert pdf in {'gauss', 'logistic'}
        self.dimY = dimY
        self.pdf = pdf
        self.mu = nn.Embedding(1, self.dimY)
        self.ln_var = nn.Embedding(1, self.dimY)  # ln(s) in logistic

    def forward(self, Y):
        cross_entropy = compute_negative_ln_prob(Y, self.mu.weight,
                                                 self.ln_var.weight, self.pdf)
        return cross_entropy
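
For a self-contained sanity check, something like the sketch below could be run against data whose MI is known in closed form. The dimensions, noise level, optimizer settings, and data construction are illustrative assumptions on my part, not anything from the repo:

import math
import torch

torch.manual_seed(0)
dimX, dimY, sigma = 10, 5, 0.5

doe = DoE(dimX, dimY, hidden=64, layers=2, pdf='gauss')
optimizer = torch.optim.Adam(doe.parameters(), lr=1e-3)

for step in range(2001):
    X = torch.randn(128, dimX)
    # Y copies the first dimY coordinates of X plus Gaussian noise, so the
    # true MI is dimY * 0.5 * ln(1 + 1 / sigma**2) nats in closed form.
    Y = X[:, :dimY] + sigma * torch.randn(128, dimY)

    out = doe(X, Y)    # forward value = hY_X - hY, i.e. -(MI estimate)
    optimizer.zero_grad()
    out.backward()     # gradients come from hY + hY_X via the .detach() trick
    optimizer.step()

    if step % 500 == 0:
        print(f'step {step}: MI estimate {-out.item():.3f} nats')

print(f'true MI: {dimY * 0.5 * math.log(1 + 1 / sigma ** 2):.3f} nats')

If the modification is wired up correctly, the printed estimate should move toward the closed-form value as training proceeds.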