FenTechSolutions / CausalDiscoveryToolbox

Package for causal inference in graphs and in the pairwise settings. Tools for graph structure recovery and dependencies are included.
MIT License

In cdt.causality.pairwise, the "RCC" example uses "Jarfo" #27

Closed: ArnoVel closed this issue 5 years ago

ArnoVel commented 5 years ago

Hi, minor problem: while experimenting I found a confusing example. The docstring example for RCC instantiates Jarfo instead of RCC, see below.

class RCC(PairwiseModel):
    """Randomized Causation Coefficient model. 2nd approach in the Fast
    Causation challenge.
    **Description:** The Randomized causation coefficient (RCC) relies on the
    projection of the empirical distributions into a RKHS using random cosine
    embeddings, then classifies the pairs using a random forest based on those
    features.
    **Data Type:** Continuous, Categorical, Mixed
    **Assumptions:** This method needs a substantial amount of labelled causal
    pairs to train itself. Its final performance depends on the training set
    used.
    Args:
        rand_coeff (int): number of randomized coefficients
        nb_estimators (int): number of estimators
        nb_min_leaves (int): number of min samples leaves of the estimator
        max_depth (): (optional) max depth of the model
        s (float): scaling
        njobs (int): number of jobs to be run on parallel (defaults to ``cdt.SETTINGS.NJOBS``)
        verbose (bool): verbosity (defaults to ``cdt.SETTINGS.verbose``)
    .. note::
       Ref : Lopez-Paz, David and Muandet, Krikamol and Schölkopf, Bernhard and Tolstikhin, Ilya O,
       "Towards a Learning Theory of Cause-Effect Inference", ICML 2015.
    Example:
        >>> from cdt.causality.pairwise import RCC
        >>> import networkx as nx
        >>> import matplotlib.pyplot as plt
        >>> from cdt.data import load_dataset
        >>> from sklearn.model_selection import train_test_split
        >>> data, labels = load_dataset('tuebingen')
        >>> X_tr, X_te, y_tr, y_te = train_test_split(data, labels, train_size=.5)
        >>> obj = Jarfo()
        >>> obj.fit(X_tr, y_tr)
        >>> # This example uses the predict() method
        >>> output = obj.predict(X_te)
        >>> # This example uses the orient_graph() method. The dataset used
        >>> # can be loaded using the cdt.data module
        >>> data, graph = load_dataset('sachs')
        >>> output = obj.orient_graph(data, nx.DiGraph(graph))
        >>> # To view the directed graph run the following command
        >>> nx.draw_networkx(output, font_size=8)
        >>> plt.show()
    """
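To make the mismatch concrete, here is a minimal sketch that statically scans the first part of the quoted example for names that are used before anything imports or assigns them. It relies only on the standard-library `ast` module, so `cdt` does not need to be installed. `EXAMPLE` and `undefined_names` are hypothetical helpers written for this illustration and are not part of cdt.

```python
import ast
import builtins

# First part of the docstring example quoted above, with the ">>>" prompts
# and comment lines stripped. Nothing here is executed, only parsed.
EXAMPLE = """
from cdt.causality.pairwise import RCC
import networkx as nx
import matplotlib.pyplot as plt
from cdt.data import load_dataset
from sklearn.model_selection import train_test_split
data, labels = load_dataset('tuebingen')
X_tr, X_te, y_tr, y_te = train_test_split(data, labels, train_size=.5)
obj = Jarfo()
obj.fit(X_tr, y_tr)
output = obj.predict(X_te)
"""


def undefined_names(source):
    """Return names that are read before an import or assignment binds them.

    Hypothetical helper: it only handles flat, top-level statements such as
    the ones found in a doctest-style example.
    """
    bound = set(dir(builtins))
    missing = []
    for stmt in ast.parse(source).body:
        # First check every name this statement reads against what is bound.
        for node in ast.walk(stmt):
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
                if node.id not in bound and node.id not in missing:
                    missing.append(node.id)
        # Then record the names this statement binds for later statements.
        for node in ast.walk(stmt):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    bound.add((alias.asname or alias.name).split(".")[0])
            elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                bound.add(node.id)
    return missing


print(undefined_names(EXAMPLE))  # prints ['Jarfo']
```

Replacing `Jarfo()` with `RCC()` in `EXAMPLE` makes the function return an empty list, which matches the fix this issue asks for: the example imports `RCC`, so it should instantiate `RCC`.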

diviyank commented 5 years ago

Hi, it should be fixed now. Thanks for the feedback!