FenTechSolutions / CausalDiscoveryToolbox

Package for causal inference in graphs and in the pairwise settings. Tools for graph structure recovery and dependencies are included.
https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/index.html
MIT License
1.08k stars, 198 forks

NCC outputs a continuous value #35

Closed: siddsuresh97 closed this issue 4 years ago

siddsuresh97 commented 4 years ago

The NCC model is supposed to output 1 or -1, but it does not. I tried it on the example from the documentation.

diviyank commented 4 years ago

It is supposed to output a prediction between -1 and 1, and the closer of the two values is retained when evaluating NCC's predictions. A neural network architecture cannot output only integer values unless a special operation, such as the Gumbel-softmax, is added.

Best, Diviyan

siddsuresh97 commented 4 years ago

The predictions do not lie between -1 and 1. Here is what I ran:

    from cdt.causality.pairwise import NCC
    from sklearn.model_selection import train_test_split

    X_tr, X_te, y_tr, y_te = train_test_split(data, labels, train_size=.8)
    obj = NCC()
    obj.fit(X_tr, y_tr, epochs=100, batch_size=32, learning_rate=0.01, verbose=None, device=None)
    obj.predict(X_te)

These are the predictions I got:

    [2.314418315887451, 0.47210752964019775, 23.46441650390625, 14.973156929016113, 1.0354838371276855, -12.186832427978516,
     -12.370292663574219, 8.425238609313965, -7.850937843322754, -10.962465286254883, 7.403403282165527, 24.28883934020996,
     -3.074826955795288, -9.25027084350586, 18.682098388671875, -4.038880825042725, 15.27067756652832, 17.017026901245117,
     -11.703863143920898, 7.830548286437988, -6.0763840675354, -3.0923893451690674, 9.908377647399902, 8.449213981628418,
     -8.861017227172852, 0.7335888147354126, -6.0915446281433105, 3.7071075439453125, 1.5803894996643066, -0.9877095818519592,
     16.72327995300293, -10.99417495727539, 16.568437576293945, 34.17158889770508, 15.287670135498047, -8.461350440979004,
     3.9821252822875977, -12.191576957702637, -8.865286827087402, -9.33437728881836, -2.64302921295166, -14.96704387664795,
     0.14124274253845215, -13.85118579864502, 16.036678314208984, 0.6504915952682495, 2.2593564987182617, -5.6124587059021,
     1.0865154266357422, 2.718655586242676, 6.070904731750488, -1.6735219955444336, -6.653247833251953, 8.754668235778809,
     10.08611011505127, 1.278486728668213, -10.487288475036621, -2.2547359466552734, 10.08383560180664, 0.05527770519256592,
     -15.76795768737793, -1.3365739583969116, 0.9914795160293579, 17.387306213378906, 16.008872985839844, 12.43630599975586,
     0.8231788873672485, -10.948040962219238, -2.858363628387451, 0.5532840490341187, -13.144486427307129, 7.203824043273926,
     -10.126144409179688, -3.24215030670166, -6.269499778747559, 0.10222327709197998, -9.891613006591797, -11.992828369140625,
     -11.922747611999512, 3.363156318664551, 19.12946128845215, 3.2570786476135254, 26.269838333129883, 8.233927726745605,
     -14.672187805175781, -0.0735774040222168, -11.518845558166504, -1.6137429475784302, 0.04899883270263672, 3.669816493988037,
     -1.6894885301589966, 22.354814529418945, 1.2691173553466797, 0.3988865613937378, -10.518499374389648, -5.545954704284668,
     -6.968181610107422, 17.4660587310791, 4.580349445343018, 9.723427772521973]

diviyank commented 4 years ago

My bad, they are meant to be real values: the higher the absolute value, the higher the confidence of the algorithm. Is your NCC trained?
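Following the reading given here (sign = predicted label, absolute value = confidence), a minimal sketch for turning the list returned by `obj.predict(X_te)` into labels and picking out the most confident pairs might look like this. The helper names `scores_to_labels` and `most_confident` are hypothetical, and treating a score of exactly 0 as "undecided" is our own assumption.

```python
# Sketch: map NCC's real-valued scores to a direction label and a confidence.
# Assumption (from this thread): sign -> label (+1 / -1), |score| -> confidence.
# A score of exactly 0 is labelled 0 ("undecided"); that choice is ours.
from typing import List, Tuple


def scores_to_labels(scores: List[float]) -> List[Tuple[int, float]]:
    """Return (label, confidence) pairs, with label in {-1, 0, 1}."""
    results = []
    for s in scores:
        label = 1 if s > 0 else (-1 if s < 0 else 0)
        results.append((label, abs(s)))
    return results


def most_confident(scores: List[float], k: int) -> List[int]:
    """Indices of the k predictions with the largest absolute score."""
    return sorted(range(len(scores)), key=lambda i: abs(scores[i]), reverse=True)[:k]


if __name__ == "__main__":
    # The first few scores reported earlier in this issue.
    sample = [2.314418315887451, 0.47210752964019775, 23.46441650390625, -12.186832427978516]
    print(scores_to_labels(sample))
    print(most_confident(sample, 2))  # -> [2, 3]
```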

siddsuresh97 commented 4 years ago

Oh ok! Just to reiterate: all positive values correspond to the label 1 and all negative values to the label -1, and a value of 26 means the sample is more likely to belong to label 1 than a value of 8 does. Similarly, a value of -26 means more confidence in the label -1 than a value of -8. As of now, the library is not very flexible for training the NCC (e.g. using different optimizers, or computing validation accuracy so that I know when to stop training). So I'm not sure whether my NCC is trained. I got those results after training the NCC on 10,000 samples generated using:

    from cdt.data import CausalPairGenerator

    generator = CausalPairGenerator('gp_add')
    data, labels = generator.generate(10000, npoints=500)

How can I be sure that the model is trained?

diviyank commented 4 years ago

Yes, that is correct.
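The thread only confirms that the sign gives the label and the magnitude gives the confidence. If one additionally ASSUMES that the raw scores are logits of a binary classifier whose positive class is label 1 (labels -1/1 mapped to 0/1 during training), a logistic sigmoid turns a score into an estimated probability of label 1, which makes "26 is more likely than 8" concrete. This is a sketch under that assumption, not a documented cdt behaviour; check the NCC source for your cdt version before relying on it. `prob_label_one` is a hypothetical helper name.

```python
# Sketch under an ASSUMPTION: NCC scores are treated as logits for label 1.
# This is not confirmed in the thread; verify against your cdt version.
import math


def prob_label_one(score: float) -> float:
    """Numerically stable logistic sigmoid, read as P(label == 1)."""
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    z = math.exp(score)  # avoids overflow for very negative scores
    return z / (1.0 + z)


if __name__ == "__main__":
    for s in (26.0, 8.0, 0.0, -8.0, -26.0):
        print(s, prob_label_one(s))
```

Under this reading, both 26 and 8 would map to probabilities very close to 1, so the ranking is preserved but the gap between them becomes small.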

The CDT is designed as a high-level library, and exposing every possible tweak to the models would make the API significantly more complex. To modify the models and learning procedures as you would like, you can always take the NCC architecture from the code (I will make it accessible directly from the API in the next version) and rewrite the learning procedure.
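If you do rewrite the learning procedure, the stopping rule asked about in this issue (stop once validation accuracy stops improving) could be structured as in the framework-agnostic sketch below. The callables `train_one_epoch` and `validation_accuracy` are hypothetical placeholders you would implement around the NCC architecture and your own optimizer; neither is part of the cdt API, and `patience` is our own illustrative parameter.

```python
# Sketch: hand-written training loop with validation-based early stopping.
# HYPOTHETICAL: train_one_epoch / validation_accuracy are user-written callables
# wrapping the NCC network; they are not functions provided by cdt.
from typing import Callable, List, Tuple


def train_with_early_stopping(
    train_one_epoch: Callable[[], float],
    validation_accuracy: Callable[[], float],
    max_epochs: int = 100,
    patience: int = 5,
) -> Tuple[int, float, List[float]]:
    """Stop when validation accuracy has not improved for `patience` epochs.

    Returns (best_epoch, best_accuracy, accuracy_history); epochs are 1-based.
    """
    best_epoch, best_acc = 0, float("-inf")
    history: List[float] = []
    for epoch in range(1, max_epochs + 1):
        train_one_epoch()            # one pass over the training pairs
        acc = validation_accuracy()  # accuracy on held-out pairs
        history.append(acc)
        if acc > best_acc:
            best_epoch, best_acc = epoch, acc
            # In a real loop, save a copy of the model weights here.
        elif epoch - best_epoch >= patience:
            break
    return best_epoch, best_acc, history


if __name__ == "__main__":
    # Fake accuracy curve standing in for a real validation run.
    fake_accs = iter([0.55, 0.62, 0.70, 0.69, 0.71, 0.70, 0.70, 0.68])
    print(train_with_early_stopping(lambda: 0.0, lambda: next(fake_accs), patience=3))
```

Keeping the model-specific parts behind two callables leaves the optimizer, batch size and device choice entirely in your hands, which is the flexibility the question asks for.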

To get an idea of how training is going, turn on verbosity (the `verbose` argument of `fit`) and watch how the loss evolves; that is a good indicator.
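Alongside the loss, a quick check that needs no change to the library is the held-out accuracy of the score signs, using the reading confirmed in this thread (positive score means label 1, negative means label -1). A minimal sketch follows; `sign_accuracy` is a hypothetical helper, and counting a score of exactly 0 as an error is an arbitrary choice.

```python
# Sketch: held-out accuracy from NCC scores.
# Assumption (confirmed in this thread): score > 0 -> label 1, score < 0 -> label -1.
# A score of exactly 0 is counted as wrong; that choice is arbitrary.
from typing import Sequence


def sign_accuracy(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Fraction of pairs whose score sign matches the true label (+1 or -1)."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if len(scores) == 0:
        raise ValueError("need at least one prediction")
    correct = sum(1 for s, y in zip(scores, labels) if s * y > 0)
    return correct / len(scores)
```

With the setup from this issue, something like `sign_accuracy(obj.predict(X_te), y_te.values.ravel())` should work if `y_te` is a pandas DataFrame or Series (we have not checked the exact type cdt returns for the labels). Accuracy close to 0.5 on balanced labels would suggest the model has not learned much yet.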

Best, Diviyan

diviyank commented 4 years ago

This should be resolved; I'll close this issue for now.