[Open] Joker222565 opened this issue 7 months ago.
import os
import random
import urllib.request
from os import path

import pandas as pd
from graphviz import Digraph
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr

from lib.evaluation import f1
from lib.mmhc import mmhc
# Turn on automatic conversion between pandas DataFrames and R data frames,
# then load the R packages used below (base for readRDS, bnlearn for the
# Bayesian-network utilities).
pandas2ri.activate()
base = importr('base')
bnlearn = importr('bnlearn')
# Reference network: fetched once from the bnlearn repository and cached
# under Input/ as an .rds file, then read back as the ground-truth DAG.
network = 'ecoli70'
network_path = 'Input/' + network + '.rds'
if not path.isfile(network_path):
    # urlretrieve does not create missing directories; without this a fresh
    # checkout with no Input/ folder fails with FileNotFoundError.
    os.makedirs(path.dirname(network_path), exist_ok=True)
    url = 'https://www.bnlearn.com/bnrepository/' + network + '/' + network + '.rds'
    urllib.request.urlretrieve(url, network_path)
dag_true = base.readRDS(network_path)
# Sample size; the sampled data set is cached as Input/<network>_<datasize>.csv.
datasize = 1000
filename = f'Input/{network}_{datasize}.csv'
if not path.isfile(filename):
    # No cached sample yet: draw `datasize` rows from the true network,
    # randomly permute the column order, and cache the result as CSV.
    data = pd.DataFrame(bnlearn.rbn(dag_true, datasize))
    data = data[random.sample(list(data.columns), len(data.columns))]
    data.to_csv(filename, index=False)
else:
    # Change dtype to 'float64' for continuous data or 'category' for
    # categorical data.
    data = pd.read_csv(filename, dtype='float64')
# Learn a DAG from the sampled data with the project's MMHC implementation.
dag_learned = mmhc(data)

# Draw the learned DAG: one graphviz node per variable, and one arrow from
# each of its parents (as reported by bnlearn.parents) to the node itself.
dot = Digraph()
for node in bnlearn.nodes(dag_learned):
    dot.node(node)
    for parent in bnlearn.parents(dag_learned, node):
        # Arrows point parent -> child so the drawing keeps the DAG's edge
        # direction (the original drew child -> parent).
        dot.edge(parent, node)
# Same <network>_<datasize> naming as the cached input CSV.
dot.render('output/' + network + '_' + str(datasize) + '.gv', view=False)
# Compare the learned structure against the ground truth.
print('f1 score is', f1(dag_true, dag_learned))
# SHD is taken against the CPDAG of the true DAG, not the DAG itself.
print('shd score is', bnlearn.shd(bnlearn.cpdag(dag_true), dag_learned)[0])
I'm seeing the same problem. May I ask whether you fixed the bug?