cnellington / Contextualized

An SKLearn-style toolbox for estimating and analyzing models, distributions, and functions with context-specific parameters.
http://contextualized.ml/
GNU General Public License v3.0

use dag_pred_np for bn errors #211

Closed cnellington closed 1 year ago

cnellington commented 1 year ago

Addresses #208, although upon further inspection the error doesn't seem to be scaling. At least here we can be confident that we're predicting the same way during training and inference. Visual inspection confirms that dag_pred_np and dag_pred perform equivalently, agreeing to the four decimal places that torch prints:

>>> from contextualized.dags.graph_utils import dag_pred, dag_pred_np
>>> import torch
>>> import numpy as np
>>> X = np.random.normal(0, 1, (5, 3))
>>> W = np.random.normal(0, 1, (5, 3, 3))
>>> np_pred = dag_pred_np(X, W)
>>> torch_pred = dag_pred(torch.Tensor(X), torch.Tensor(W))
>>> torch_pred[0]
tensor([0.1719, 1.9577, 0.6442])
>>> np_pred[0]
array([0.17191051, 1.95765787, 0.6441674 ])
>>> torch_pred[1]
tensor([-2.8746,  0.9607, -0.9306])
>>> np_pred[1]
array([-2.87464662,  0.96066154, -0.93063571])
>>> torch_pred[2]
tensor([ 7.1357, -1.1330, -4.0369])
>>> np_pred[2]
array([ 7.1356685 , -1.13298522, -4.03687548])
>>> exit()
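
The visual comparison above could also be made numerical. Below is a minimal sketch, not part of this PR, that runs without the library installed. It makes two assumptions: batched_pred_np is a hypothetical stand-in that computes a per-sample product x_i @ W_i (the actual behavior of dag_pred and dag_pred_np is not confirmed here), and preds_match is a hypothetical helper name. The float32 cast mimics the default dtype of torch.Tensor, which is why the two outputs agree only to about single precision.

```python
import numpy as np


def batched_pred_np(X, W):
    # Hypothetical stand-in (assumption): per-sample product x_i @ W_i.
    # X has shape (n, p), W has shape (n, p, p); the result has shape (n, p).
    return np.matmul(X[:, np.newaxis, :], W).squeeze(1)


def preds_match(a, b, rtol=1e-4, atol=1e-5):
    # Hypothetical helper: True if two prediction arrays agree within
    # tolerances loose enough to absorb float32 rounding.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return bool(np.allclose(a, b, rtol=rtol, atol=atol))


rng = np.random.default_rng(0)
X = rng.normal(0, 1, (5, 3))
W = rng.normal(0, 1, (5, 3, 3))

# float64 path, analogous to the NumPy call in the session above.
pred64 = batched_pred_np(X, W)
# float32 path, analogous to wrapping the inputs in torch.Tensor.
pred32 = batched_pred_np(X.astype(np.float32), W.astype(np.float32))

max_abs_diff = float(np.max(np.abs(pred64 - pred32)))
print("match:", preds_match(pred64, pred32))
print("max abs diff:", max_abs_diff)
```

With the library available, the same helper could be applied to the real outputs, e.g. preds_match(np_pred, torch_pred.numpy()), which would turn the side-by-side printout into a single pass/fail check.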