cnellington / Contextualized

An SKLearn-style toolbox for estimating and analyzing models, distributions, and functions with context-specific parameters.
http://contextualized.ml/
GNU General Public License v3.0

Bug with Shape of Rho in Contextualized Correlation Networks #228

Closed: vexvexctor closed this issue 4 months ago

vexvexctor commented 6 months ago

Hey,

I was experimenting with the code that computes rho and rho^2. When I printed the shape of rho, the output was only `()`, the shape of a 0-d (scalar) value. When I did the same with rho^2, it returned the expected shape.

Thanks, Vexvexctor

blengerich commented 6 months ago

Hi @vexvexctor , thanks for the bug report. Could you please provide a minimal working example to reproduce the bug?

cnellington commented 5 months ago

To reproduce

import numpy as np
from contextualized.easy import ContextualizedCorrelationNetworks

n_samples, c_dim, x_dim = 100, 2, 4
C = np.zeros((n_samples, c_dim))  # contexts
X = np.zeros((n_samples, x_dim))  # features whose correlations are modeled
model = ContextualizedCorrelationNetworks()
model.fit(C, X, max_epochs=1)

# individual_preds=True keeps a leading axis of per-model predictions
rho = model.predict_correlation(C, individual_preds=True, squared=False)
assert rho.shape == (1, n_samples, x_dim, x_dim)

# individual_preds=False should average that axis away, leaving one matrix per sample
rho = model.predict_correlation(C, individual_preds=False, squared=False)
assert rho.shape == (n_samples, x_dim, x_dim)  # This fails
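One plausible cause, offered as an assumption rather than something confirmed in this thread: if the averaged branch reduces the stacked per-model predictions with `np.mean` and no `axis` argument, NumPy averages over every element and returns a single number whose shape is `()`. Averaging over `axis=0` removes only the leading model axis. The sketch below uses a hypothetical stacked array `rhos` (random values, not real model output) as a stand-in for the library's internals.

```python
import numpy as np

# Hypothetical stand-in for stacked per-model correlation predictions,
# shaped (n_models, n_samples, x_dim, x_dim) like the individual_preds=True output.
n_models, n_samples, x_dim = 1, 100, 4
rng = np.random.default_rng(0)
rhos = rng.uniform(-1.0, 1.0, size=(n_models, n_samples, x_dim, x_dim))

# Reducing without an axis collapses every dimension into one number.
rho_scalar = np.mean(rhos)
print(rho_scalar.shape)  # ()

# Reducing over the model axis keeps one x_dim-by-x_dim matrix per sample.
rho_mean = np.mean(rhos, axis=0)
print(rho_mean.shape)  # (100, 4, 4)
```

Applying the same `axis=0` mean to the array returned with `individual_preds=True` should give the `(n_samples, x_dim, x_dim)` shape the second assertion expects, which may serve as a workaround until the averaged branch is fixed.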