salesforce / OmniXAI

OmniXAI: A Library for eXplainable AI
BSD 3-Clause "New" or "Revised" License

Some bugs in MACE #69

Closed: SeibertronSS closed this issue 1 year ago

SeibertronSS commented 1 year ago

I often get the following two errors when using MACE on tabular data:

Traceback (most recent call last):
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/base.py", line 284, in explain
    explanations[name] = self.explainers[name].explain(X=X, **param)
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/tabular/counterfactual/mace/mace.py", line 119, in explain
    candidates, indices = self.recall.get_cf_features(x, desired_label)
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/tabular/counterfactual/mace/retrieval.py", line 187, in get_cf_features
    y, indices = self.get_nn_samples(instance, desired_label)
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/tabular/counterfactual/mace/retrieval.py", line 174, in get_nn_samples
    indices = self._knn_query(query, desired_label, self.num_neighbors)[0]
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/tabular/counterfactual/mace/retrieval.py", line 122, in _knn_query
    indices, distances = self.knn_models[label].knn_query(x, k=k)
RuntimeError: Cannot return the results in a contigious 2D array. Probably ef or M is too small
Traceback (most recent call last):
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/base.py", line 284, in explain
    explanations[name] = self.explainers[name].explain(X=X, **param)
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/tabular/counterfactual/mace/mace.py", line 139, in explain
    cfs_df = cfs.to_pd()
AttributeError: 'NoneType' object has no attribute 'to_pd'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "z/mace.py", line 99, in <module>
    explanations = explainers.explain(test_instances)
  File "/export/home/xxx/miniconda3/envs/kf3/lib/python3.8/site-packages/omnixai/explainers/base.py", line 286, in explain
    raise type(e)(f"Explainer {name} -- {str(e)}")
AttributeError: Explainer mace -- 'NoneType' object has no attribute 'to_pd'

I run the same code each time, but I get one of these two unexpected errors. The CPU error I reported earlier (https://github.com/salesforce/OmniXAI/issues/68) seems to have been caused by these same errors; their tracebacks were just not shown in the container.

yangwenz commented 1 year ago

Hi, it looks like the first problem comes from the KNN index. Could you share the data and the model so that we can reproduce it? The second problem is a bit weird; it seems there is a problem with the input features. If the features have formatting issues, the first problem may also occur. Could you provide some samples of your dataset and the preprocessing function?

SeibertronSS commented 1 year ago

I did some debugging yesterday. The first problem occurs because the training dataset is so small that the requested number of neighbors cannot be found. I added fault tolerance to the _knn_query method of retrieval.py:

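from typing import List

import numpy as np


def _knn_query(self, x: np.ndarray, label: int, k: int) -> List:
    """
    Finds the nearest neighbors given a query instance.

    :param x: The query instance.
    :param label: The desired label.
    :param k: The number of neighbors.
    :return: A list of the indices of the nearest neighbors.
    """
    try:
        indices, distances = self.knn_models[label].knn_query(x, k=k)
        # Skip neighbors at distance 0, i.e. rows identical to the query
        neighbors = [[idx[i] for i in range(len(idx)) if dists[i] > 0] for idx, dists in zip(indices, distances)]
    except RuntimeError:
        # hnswlib raises RuntimeError when it cannot return k neighbors, e.g. when the
        # index for this label holds fewer than k items. Fall back to every id in the index,
        # wrapped in a list so the result keeps the one-list-per-query shape that
        # get_nn_samples() indexes with [0] (a single query row is assumed).
        neighbors = [self.knn_models[label].get_ids_list()]
    return neighbors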
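Instead of catching the exception, another option is to clamp k to the number of items actually stored in the label's index before querying. The sketch below is a standalone helper, not OmniXAI code; it assumes index is an hnswlib.Index (it only uses its get_current_count() and knn_query(x, k=...) methods) and keeps the same one-list-per-query output and zero-distance filtering as the method above. hnswlib may still raise this error in rare cases (its message also mentions ef and M), so keeping the try/except as a last resort seems reasonable.

```python
from typing import List

import numpy as np


def knn_query_clamped(index, x: np.ndarray, k: int) -> List[List[int]]:
    """Query `index` for up to k neighbors of each row in `x` without exceeding its size.

    `index` is assumed to be an hnswlib.Index, or any object exposing the same
    get_current_count() and knn_query(x, k=...) methods.
    """
    x = np.atleast_2d(x)
    # Never ask for more neighbors than the index holds; asking for more is what
    # triggers "Cannot return the results in a contigious 2D array".
    k = min(k, index.get_current_count())
    if k == 0:
        return [[] for _ in range(len(x))]
    indices, distances = index.knn_query(x, k=k)
    # Same filtering as _knn_query: drop neighbors at distance 0 (rows identical to the query)
    return [
        [int(i) for i, d in zip(idx, dists) if d > 0]
        for idx, dists in zip(indices, distances)
    ]
```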

The second problem is that self.diversity.get_diverse_cfs sometimes returns None in the explain method of mace.py, which makes self.refinement.refine return None as well. I added a fault-tolerance check here too:

def explain(
    self,
    X: Tabular,
    y: Union[List, np.ndarray] = None,
    max_number_examples: int = 5,
    **kwargs
) -> CFExplanation:
    """
    Generates counterfactual explanations.

    :param X: A batch of input instances. When ``X`` is `pd.DataFrame`
        or `np.ndarray`, ``X`` will be converted into `Tabular` automatically.
    :param y: A batch of the desired labels, which should be different from the predicted labels of ``X``.
        If ``y = None``, the desired labels will be the labels different from the predicted labels of ``X``.
    :param max_number_examples: The maximum number of the generated counterfactual
        examples per class for each input instance.
    :return: A CFExplanation object containing the generated explanations.
    """
    if y is not None:
        assert len(y) == X.shape[0], (
            f"The length of `y` should equal the number of instances in `X`, got {len(y)} != {X.shape[0]}"
        )
    X = self._to_tabular(X).remove_target_column()
    scores = self.predict_function(X)
    labels = np.argmax(scores, axis=1)
    num_classes = scores.shape[1]

    explanations = CFExplanation()
    for i in range(X.shape[0]):
        x = X.iloc(i)
        label = int(labels[i])
        if y is None or y[i] == label:
            desired_labels = [z for z in range(num_classes) if z != label]
        else:
            desired_labels = [int(y[i])]

        all_cfs = []
        for desired_label in desired_labels:
            # Get candidate features
            candidates, indices = self.recall.get_cf_features(x, desired_label)

            # Find counterfactual examples
            examples = self.solver.get_cf_examples(self.predict_function, x, desired_label, candidates)
            if not examples:
                # If GLD fails, try to apply the greedy method
                examples = Greedy().get_cf_examples(
                    self.predict_function, x, desired_label, candidates)

            # Generate diverse counterfactual examples
            if examples:
                cfs = self.diversity.get_diverse_cfs(
                    self.predict_function, x, examples["cfs"],
                    oracle_function=lambda _s: int(desired_label == np.argmax(_s)),
                    desired_label=desired_label, k=max_number_examples
                )
                cfs = self.refinement.refine(
                    self.predict_function, x, cfs,
                    oracle_function=lambda _s: int(desired_label == np.argmax(_s))
                )
                # Fault tolerance: get_diverse_cfs (and therefore refine) can return None,
                # so skip this desired label instead of calling to_pd() on None
                if cfs is not None:
                    cfs_df = cfs.to_pd()
                    cfs_df["label"] = desired_label
                    all_cfs.append(cfs_df)

        instance_df = x.to_pd()
        instance_df["label"] = label
        explanations.add(query=instance_df, cfs=pd.concat(all_cfs) if len(all_cfs) > 0 else None)
    return explanations

SeibertronSS commented 1 year ago

I'm very sorry, but due to company regulations I can't provide the model and dataset I used. I've found that MACE has these problems with particularly small datasets; the tabular data I am using has only 100 rows.

SeibertronSS commented 1 year ago

In addition to the above two problems, I found another one. When MACE creates desired_labels, it assumes the model can predict every class in the training set. When the model performs particularly poorly, however, it may never predict some classes, so no knn_model exists for those labels during the KNN search. Concretely, the line indices, distances = self.knn_models[label].knn_query(x, k=k) in the _knn_query method of retrieval.py raises a KeyError.
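The KeyError can be anticipated by checking which labels actually received an index. The sketch below assumes, consistent with the KeyError described above, that retrieval.py builds one KNN index per label the model predicts on the training data, so a class the model never predicts gets no entry in knn_models. Both helper names are hypothetical; the second mirrors how explain() builds desired_labels but drops labels without an index.

```python
from typing import Dict, List

import numpy as np


def group_rows_by_predicted_label(scores: np.ndarray) -> Dict[int, np.ndarray]:
    """Map each predicted label to the training-row positions predicted as that label.

    Assumption: one KNN index is built per key of this dict, so labels the model
    never predicts have no index at all.
    """
    labels = np.argmax(scores, axis=1)
    return {int(c): np.flatnonzero(labels == c) for c in np.unique(labels)}


def desired_labels_with_index(label: int, num_classes: int, groups: Dict[int, np.ndarray]) -> List[int]:
    """Like explain()'s desired_labels, but skipping labels that have no index."""
    return [z for z in range(num_classes) if z != label and z in groups]
```

For example, if a 3-class model predicts only classes 0 and 1 on the training data, desired_labels_with_index(0, 3, groups) returns [1] rather than [1, 2], so the query for class 2 that would raise KeyError is never made.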

yangwenz commented 1 year ago

For a dataset this small, we could implement a naive KNN search instead of using hnswlib. If you think it is necessary, we can add it in the new version.
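A naive KNN search for a dataset this small could be an exact brute-force scan. The class below is a hypothetical sketch, not part of OmniXAI: it mimics hnswlib.Index's knn_query (returning ids and squared L2 distances, as hnswlib's "l2" space does) and get_current_count(), but clamps k to the number of stored rows instead of raising. With about 100 rows, computing roughly 100 distances per query is cheap.

```python
from typing import Optional

import numpy as np


class BruteForceKNN:
    """Exact nearest-neighbor search that scans every stored row (hypothetical sketch)."""

    def __init__(self, data: np.ndarray, ids: Optional[np.ndarray] = None):
        self.data = np.asarray(data, dtype=np.float64)
        # Row positions are used as ids unless explicit ids are given
        self.ids = np.arange(len(self.data)) if ids is None else np.asarray(ids)

    def get_current_count(self) -> int:
        return len(self.data)

    def knn_query(self, x: np.ndarray, k: int = 1):
        x = np.atleast_2d(np.asarray(x, dtype=np.float64))
        # Return at most as many neighbors as there are rows, instead of raising
        k = min(k, len(self.data))
        # Squared Euclidean distance from every query row to every stored row
        d = ((x[:, None, :] - self.data[None, :, :]) ** 2).sum(axis=2)
        order = np.argsort(d, axis=1, kind="stable")[:, :k]
        return self.ids[order], np.take_along_axis(d, order, axis=1)
```

Because it exposes the same two methods, an object like this could stand in for a per-label hnswlib index in knn_models when the training set is tiny.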

yangwenz commented 1 year ago

For KNN, if all the predictions have the same class/label, the KNN query fails because the desired label has no entry in the KNN index. In this case, MACE can only be applied without the KNN retrieval module.

yangwenz commented 1 year ago

If you are in a hurry, the line you can modify is "candidates, indices = self.recall.get_cf_features(x, desired_label)" in mace.py. You can replace this call with a function that returns all the feature values as candidate features. Because the dataset is small, this is acceptable.
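One way to write such a replacement is sketched below. It is hypothetical: the function name is made up, and it assumes get_cf_features returns a dict mapping each feature name to a list of candidate values, plus the positions of the training rows used. Check the actual return value in your installed omnixai/explainers/tabular/counterfactual/mace/retrieval.py and adapt. Every distinct training value that differs from the query's value becomes a candidate, and every row counts as retrieved.

```python
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd


def all_values_as_candidates(x: pd.DataFrame, train: pd.DataFrame) -> Tuple[Dict[str, List], np.ndarray]:
    """Use every distinct training value as a candidate feature value (no KNN retrieval).

    x: single-row DataFrame holding the query instance.
    train: training features (same columns as x, without the target column).
    The (candidates, indices) return format is an assumption, not OmniXAI's documented API.
    """
    row = x.iloc[0]
    candidates = {}
    for col in train.columns:
        # Distinct values in order of appearance, as plain Python scalars,
        # excluding the value the query already has
        values = [v for v in train[col].drop_duplicates().tolist() if v != row[col]]
        if values:
            candidates[col] = values
    # Every training row is treated as "retrieved"
    return candidates, np.arange(len(train))
```

Inside explain(), the call would then become something like candidates, indices = all_values_as_candidates(x.to_pd(), train_df), where train_df is a hypothetical DataFrame of the training features.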

SeibertronSS commented 1 year ago

I added some fault tolerance to OmniXAI's code so that MACE works even with small datasets and poorly performing models. I think MACE should support small datasets and poorly performing models.

sarangs-ntnu commented 1 year ago

I am getting the error below in MACE:

RuntimeError                              Traceback (most recent call last)
~/anaconda3/envs/ntnu_meticos/lib/python3.7/site-packages/omnixai/explainers/base.py in explain(self, X, params, run_predict)
    283 param = params.get(name, {})
--> 284 explanations[name] = self.explainers[name].explain(X=X, **param)
    285 except Exception as e:

~/anaconda3/envs/ntnu_meticos/lib/python3.7/site-packages/omnixai/explainers/tabular/counterfactual/mace/mace.py in explain(self, X, y, max_number_examples, **kwargs)
    124 # Get candidate features
--> 125 candidates, indices = self.recall.get_cf_features(x, desired_label)
    126

~/anaconda3/envs/ntnu_meticos/lib/python3.7/site-packages/omnixai/explainers/tabular/counterfactual/mace/retrieval.py in get_cf_features(self, instance, desired_label)
    186 x = instance.to_pd(copy=False)
--> 187 y, indices = self.get_nn_samples(instance, desired_label)
    188 cate_candidates, cont_candidates = {}, {}

~/anaconda3/envs/ntnu_meticos/lib/python3.7/site-packages/omnixai/explainers/tabular/counterfactual/mace/retrieval.py in get_nn_samples(self, instance, desired_label)
    173 )
--> 174 indices = self._knn_query(query, desired_label, self.num_neighbors)[0]
    175 y = self.subset.iloc(indices).to_pd(copy=False)

~/anaconda3/envs/ntnu_meticos/lib/python3.7/site-packages/omnixai/explainers/tabular/counterfactual/mace/retrieval.py in _knn_query(self, x, label, k)
    121 """
--> 122 indices, distances = self.knn_models[label].knn_query(x, k=k)
    123 neighbors = [[idx[i] for i in range(len(idx)) if dists[i] > 0] for idx, dists in zip(indices, distances)]

RuntimeError: Cannot return the results in a contigious 2D array. Probably ef or M is too small

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_3818713/1430780571.py in <module>
     14 # Generate explanations
     15 test_instances = test_data[0:7]
---> 16 local_explanations = explainers.explain(X=test_instances)
     17 global_explanations = explainers.explain_global(
     18 params={"pdp": {"features": []}}

~/anaconda3/envs/ntnu_meticos/lib/python3.7/site-packages/omnixai/explainers/base.py in explain(self, X, params, run_predict)
    284 explanations[name] = self.explainers[name].explain(X=X, **param)
    285 except Exception as e:
--> 286 raise type(e)(f"Explainer {name} -- {str(e)}")
    287 return explanations
    288

RuntimeError: Explainer mace -- Cannot return the results in a contigious 2D array. Probably ef or M is too small