FenTechSolutions / CausalDiscoveryToolbox

Package for causal inference in graphs and in the pairwise settings. Tools for graph structure recovery and dependencies are included.
https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/index.html
MIT License

In cdt.causality.pairwise, the "RCC" example uses "Jarfo" #27

Closed. ArnoVel closed this issue 5 years ago

ArnoVel commented 5 years ago

Hi, a minor problem: while experimenting I found a confusing example. The docstring example for RCC imports RCC but then instantiates Jarfo (`obj = Jarfo()`). See below.

class RCC(PairwiseModel):
    """Randomized Causation Coefficient model. 2nd approach in the Fast
    Causation challenge.
    **Description:** The Randomized causation coefficient (RCC) relies on the
    projection of the empirical distributions into a RKHS using random cosine
    embeddings, then classifies the pairs using a random forest based on those
    features.
    **Data Type:** Continuous, Categorical, Mixed
    **Assumptions:** This method needs a substantial amount of labelled causal
    pairs to train itself. Its final performance depends on the training set
    used.
    Args:
        rand_coeff (int): number of randomized coefficients
        nb_estimators (int): number of estimators
        nb_min_leaves (int): number of min samples leaves of the estimator
        max_depth (): (optional) max depth of the model
        s (float): scaling
        njobs (int): number of jobs to be run on parallel (defaults to ``cdt.SETTINGS.NJOBS``)
        verbose (bool): verbosity (defaults to ``cdt.SETTINGS.verbose``)
    .. note::
       Ref : Lopez-Paz, David and Muandet, Krikamol and Schölkopf, Bernhard and Tolstikhin, Ilya O,
       "Towards a Learning Theory of Cause-Effect Inference", ICML 2015.
    Example:
        >>> from cdt.causality.pairwise import RCC
        >>> import networkx as nx
        >>> import matplotlib.pyplot as plt
        >>> from cdt.data import load_dataset
        >>> from sklearn.model_selection import train_test_split
        >>> data, labels = load_dataset('tuebingen')
        >>> X_tr, X_te, y_tr, y_te = train_test_split(data, labels, train_size=.5)
        >>>
        >>> obj = Jarfo()
        >>> obj.fit(X_tr, y_tr)
        >>> # This example uses the predict() method
        >>> output = obj.predict(X_te)
        >>>
        >>> # This example uses the orient_graph() method. The dataset used
        >>> # can be loaded using the cdt.data module
        >>> data, graph = load_dataset('sachs')
        >>> output = obj.orient_graph(data, nx.DiGraph(graph))
        >>>
        >>> # To view the directed graph run the following command
        >>> nx.draw_networkx(output, font_size=8)
        >>> plt.show()
    """

diviyank commented 5 years ago

Hi, this should be fixed now. Thanks for the feedback!
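
As a side note, this kind of mismatch (an example that uses a name it never imports) can be caught mechanically. The sketch below is a rough heuristic using only the standard library: it extracts the `>>>` lines with `doctest.DocTestParser` and walks them with `ast` to list names that are read before being imported or assigned. The helper name `undefined_names_in_examples` is hypothetical, and the check is deliberately naive: it does not understand function definitions or comprehension variables, so it may report false positives on more elaborate examples.

```python
import ast
import builtins
import doctest


def undefined_names_in_examples(docstring):
    """Return the names a docstring's ``>>>`` examples read without defining."""
    examples = doctest.DocTestParser().get_examples(docstring)
    tree = ast.parse("".join(example.source for example in examples))
    defined = set(dir(builtins))
    missing = set()
    for statement in tree.body:
        nodes = list(ast.walk(statement))
        # Names read by this statement must have been defined by earlier ones.
        for node in nodes:
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
                if node.id not in defined:
                    missing.add(node.id)
        # Record what this statement defines: imported and assigned names.
        for node in nodes:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    defined.add((alias.asname or alias.name).split(".")[0])
            elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                defined.add(node.id)
    return sorted(missing)


# Shortened version of the reported docstring example.
DOC = """
Example:
    >>> from cdt.causality.pairwise import RCC
    >>> from cdt.data import load_dataset
    >>> data, labels = load_dataset('tuebingen')
    >>> obj = Jarfo()
    >>> obj.fit(data, labels)
"""

print(undefined_names_in_examples(DOC))
```

On the shortened example above, the only name reported is `Jarfo`; with `obj = RCC()` the list is empty.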