FenTechSolutions / CausalDiscoveryToolbox

Package for causal inference in graphs and in the pairwise setting. Tools for graph structure recovery and dependencies are included.
https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/index.html
MIT License

GNN never stops even when p-value < 0.01 #34

huangwei2913 closed this issue 4 years ago

huangwei2913 commented 4 years ago

I have a Tesla GPU with CUDA installed:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.87.00    Driver Version: 418.87.00    CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:65:00.0 Off |                  Off |
| N/A   83C    P0    43W / 250W |   1087MiB / 16280MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      2648      C   python                                      1077MiB |
+-----------------------------------------------------------------------------+

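My problem is that the GNN never stops running, even after the smallest p-value is obtained. What is wrong with my code? Please help.

import time

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from cdt.independence.graph import FSGNN

# df is assumed to be a pandas DataFrame of observations (one column per variable), loaded earlier
Fsgnn = FSGNN(train_epochs=100, test_epochs=50, l1=0.1, batch_size=1000)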

start_time = time.time()
ugraph = Fsgnn.predict(df, threshold=1e-7)
print("--- Execution time: %.2f seconds ---" % (time.time() - start_time))
nx.draw_networkx(ugraph, font_size=8)  # The plot function allows for quick visualization of the graph.
# plt.show()
# List results
list00 = pd.DataFrame(list(ugraph.edges(data='weight')))
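The skeleton returned by FSGNN.predict is an undirected networkx graph, and the edges(data='weight') call above reads its edge weights. The sketch below illustrates, with made-up scores, how a score threshold controls how dense such a skeleton is. It assumes, as a simplification rather than CDT's exact implementation, that two variables are linked whenever their symmetric score exceeds the threshold; skeleton_from_scores is a hypothetical helper. Under that assumption a threshold as small as 1e-7 keeps almost every pair, which leaves more edges for the CGNN orientation step to process.

```python
import itertools

import networkx as nx


def skeleton_from_scores(scores, names, threshold):
    """Hypothetical helper: link two variables when their pairwise score
    exceeds `threshold`. `scores` is a symmetric matrix (list of lists)."""
    graph = nx.Graph()
    graph.add_nodes_from(names)  # keep variables that end up with no edges
    for i, j in itertools.combinations(range(len(names)), 2):
        if scores[i][j] > threshold:
            graph.add_edge(names[i], names[j], weight=scores[i][j])
    return graph


# Made-up symmetric scores for four variables (not FSGNN output).
names = ["A", "B", "C", "D"]
scores = [
    [0.0, 0.40, 1e-3, 1e-5],
    [0.40, 0.0, 0.20, 1e-4],
    [1e-3, 0.20, 0.0, 0.30],
    [1e-5, 1e-4, 0.30, 0.0],
]

for threshold in (1e-7, 0.05):
    skeleton = skeleton_from_scores(scores, names, threshold)
    print(threshold, skeleton.number_of_edges())
```

With these made-up scores, a threshold of 1e-7 keeps all six possible edges, while 0.05 keeps only A-B, B-C and C-D.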

from cdt.causality.graph import CGNN

Cgnn = CGNN(nruns=16, train_epochs=200, test_epochs=100, batch_size=1000)
start_time = time.time()
dgraph = Cgnn.orient_undirected_graph(df, ugraph)
print("--- Execution time: %.2f seconds ---" % (time.time() - start_time))
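One way to see why the orientation step can look as if it never finishes is to count training epochs. The function below is a rough, hypothetical estimate, not CDT code. It assumes that CGNN scores each candidate graph by training nruns networks for train_epochs + test_epochs epochs each, and that one hill-climbing pass evaluates the starting graph plus one reversal per skeleton edge. It ignores batching, parallelism, and how many passes the search actually needs.

```python
def cgnn_epoch_budget(n_edges, nruns=16, train_epochs=200, test_epochs=100, passes=1):
    """Rough count of training epochs for a hill-climbing orientation search.

    Hypothetical estimate: one evaluation of the starting graph plus one
    evaluation per edge reversal tried, per pass; each evaluation trains
    `nruns` networks for `train_epochs + test_epochs` epochs.
    """
    epochs_per_evaluation = nruns * (train_epochs + test_epochs)
    evaluations = 1 + n_edges * passes
    return evaluations * epochs_per_evaluation


# Defaults mirror the CGNN settings used in the script above.
for n_edges in (10, 50, 200):
    print(n_edges, cgnn_epoch_budget(n_edges))
```

Under these assumptions each evaluation with the settings above already costs 16 * 300 = 4,800 epochs, so a skeleton with 200 edges needs 964,800 epochs for a single pass. Fewer runs, fewer epochs, or a sparser skeleton shrink this budget.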

# Plot the output graph
plt.figure()  # open a new figure so this plot does not overlap the skeleton plot
nx.draw_networkx(dgraph, font_size=8)  # The plot function allows for quick visualization of the graph.
# plt.show()
# Print output results:
list22 = pd.DataFrame(list(dgraph.edges(data='weight')), columns=['Cause', 'Effect', 'Score'])
print(list22)
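As a sanity check on the final output, the directed edges from the orientation step can be compared with the undirected skeleton they came from. oriented_edge_table below is a hypothetical helper, not part of CDT, built only on networkx and pandas; the toy graphs stand in for ugraph and dgraph. Where an edge has no 'weight' attribute, edges(data='weight') yields None, which pandas shows as a missing value in the Score column.

```python
import networkx as nx
import pandas as pd


def oriented_edge_table(skeleton, dag):
    """List the directed edges of `dag`, flagging whether each one
    corresponds to an undirected edge of `skeleton` (hypothetical helper)."""
    rows = [
        {
            "Cause": cause,
            "Effect": effect,
            "Score": score,  # None if the edge has no 'weight' attribute
            "InSkeleton": skeleton.has_edge(cause, effect),
        }
        for cause, effect, score in dag.edges(data="weight")
    ]
    return pd.DataFrame(rows, columns=["Cause", "Effect", "Score", "InSkeleton"])


# Toy stand-ins for ugraph (skeleton) and dgraph (oriented output).
skeleton = nx.Graph([("A", "B"), ("B", "C")])
dag = nx.DiGraph()
dag.add_edge("A", "B", weight=0.8)
dag.add_edge("C", "B")              # no weight attribute
dag.add_edge("A", "C", weight=0.1)  # not present in the skeleton

print(oriented_edge_table(skeleton, dag))
```

Rows with InSkeleton set to False would point to edges that the orientation step did not take from the FSGNN skeleton.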