amparore / leaf

A Python framework for the quantitative evaluation of eXplainable AI methods

More work for multi-label classification #7

Open singsinghai opened 1 year ago

singsinghai commented 1 year ago

Hi @amparore,

In the previous commit I made changes to help LEAF display results for multiclass problems. However, I noticed that the fidelity score was always very low (close to 0), and then I realized the problem comes from the eval_whitebox_classifier function, due to these lines:

BBY0, WBY0 = bb_classifier(NX0)[:,0], g.predict(SNX0)
BBY1, WBY1 = bb_classifier(NX1)[:,0], g.predict(SNX1)
if label_x0 == 1:
    WBY0, WBY1 = 1 - WBY0, 1 - WBY1
BBCLS0, WBCLS0 = BBY0 > 0.5, WBY0 > 0.5
BBCLS1, WBCLS1 = BBY1 > 0.5, WBY1 > 0.5

In the first two lines, you retrieve the output probability of the first class, and the if statement inverts the probability when the label is 1. To make this general for multiclass, I suggest changing the code above to:

BBY0, WBY0 = bb_classifier(NX0)[:,label_x0], g.predict(SNX0)      # The label_x0 is now the index of the bb class
BBY1, WBY1 = bb_classifier(NX1)[:,label_x0], g.predict(SNX1)      # while the g.predict already got coef and intercept from the bb class earlier

BBCLS0, WBCLS0 = BBY0 > 0.5, WBY0 > 0.5
BBCLS1, WBCLS1 = BBY1 > 0.5, WBY1 > 0.5
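To illustrate the difference on toy data (a sketch with made-up 3-class probabilities, not the actual LEAF code), note how column 0 and the column of the explained class can disagree:

```python
import numpy as np

# Hypothetical black-box output: 3-class probabilities for 4 samples
proba = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])

label_x0 = 1  # index of the class being explained

# Old binary code always reads column 0
binary_col = proba[:, 0]
# Multiclass fix reads the column of the explained class
class_col = proba[:, label_x0]

print((binary_col > 0.5).tolist())  # [True, False, False, True]
print((class_col > 0.5).tolist())   # [False, True, False, False]
```

With `label_x0 = 0` the two selections coincide, which is why the original binary code worked only for that case.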

Now the fidelity becomes normal again


I will check whether there are further lines to fix for the multiclass case, but I think the change above is already sufficient. Would you mind if I create a commit for this afterwards?

amparore commented 1 year ago

It should be correct, as the fidelity is defined as the accuracy over the predicted class (as predicted by the two classifiers). Note also that at lines 149-150 I make the assumption that the separation between the classes is at 0.5. This could also need some changes to be used with multiclass problems.

singsinghai commented 1 year ago

Because we only consider the class Y of the instance to explain, the problem reduces to the binary classification problem of whether we predict class Y or not. I think the 0.5 threshold is ok.

The only problem is when we explain an instance whose class probabilities are all below 0.5. That is a weird datapoint the model cannot predict well, and LEAF produces bad metrics for it. This indicates our explanation is very bad for this instance (which is true, since even the black-box model cannot produce a good prediction).
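A toy multiclass distribution (made-up numbers) shows why a fixed 0.5 threshold can fail where argmax still works: with many classes, it is easy for no single class to reach 0.5 even though there is a clear top class.

```python
import numpy as np

# Made-up 4-class distribution where no class reaches 0.5
proba = np.array([0.35, 0.30, 0.20, 0.15])

# Thresholding at 0.5 marks no class as predicted...
print(bool((proba > 0.5).any()))  # False
# ...while argmax still yields a definite top class
print(int(np.argmax(proba)))      # 0
```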

amparore commented 1 year ago

In that part, multiclass should probably be treated differently, maybe taking the argmax of the output distribution instead of using thresholds.

singsinghai commented 1 year ago

I understand the idea. The code should now become:

# np.argmax(proba_list, axis=1) has shape (NX_len,)
BBY0, WBY0 = np.argmax(bb_classifier(NX0), axis=1), g.predict(SNX0)
BBY1, WBY1 = np.argmax(bb_classifier(NX1), axis=1), g.predict(SNX1)

BBCLS0, WBCLS0 = BBY0 == label_x0, WBY0 > 0.5
BBCLS1, WBCLS1 = BBY1 == label_x0, WBY1 > 0.5
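A quick sanity check of the argmax step on toy data (a fixed probability array stands in for `bb_classifier(NX0)`; the values are invented):

```python
import numpy as np

label_x0 = 2  # index of the explained class

# Stand-in for bb_classifier(NX0): 3-class probabilities for 5 samples
proba = np.array([
    [0.1, 0.2, 0.7],
    [0.5, 0.3, 0.2],
    [0.2, 0.2, 0.6],
    [0.3, 0.4, 0.3],
    [0.1, 0.1, 0.8],
])

BBY0 = np.argmax(proba, axis=1)  # shape (5,): top class per sample
BBCLS0 = BBY0 == label_x0        # True where the top class is the explained one

print(BBY0.tolist())    # [2, 0, 2, 1, 2]
print(BBCLS0.tolist())  # [True, False, True, False, True]
```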
amparore commented 1 year ago

My thought was that there should probably be a difference when the output is a single value versus a probability distribution. Hence the code should probably look like:

    BBY0, WBY0 = bb_classifier(NX0)[:,0], g.predict(SNX0)
    BBY1, WBY1 = bb_classifier(NX1)[:,0], g.predict(SNX1)
    if len(BBY0)==1: # single class predictor, use a threshold
        if label_x0 == 1:
            WBY0, WBY1 = 1 - WBY0, 1 - WBY1
        BBCLS0, WBCLS0 = BBY0 > 0.5, WBY0 > 0.5
        BBCLS1, WBCLS1 = BBY1 > 0.5, WBY1 > 0.5
    else: # multiclass predictor, determine the top class
        ...
singsinghai commented 1 year ago

I think you might have forgotten about the use of the first two lines, so I want to re-clarify a few points here just to make sure:

Therefore, in this case, I'm pretty sure my suggested code above works correctly for both single-class and multiclass problems:

I understand the idea. The code should now become:

# np.argmax(proba_list, axis=1) has shape (NX_len,)
BBY0, WBY0 = np.argmax(bb_classifier(NX0), axis=1), g.predict(SNX0)
BBY1, WBY1 = np.argmax(bb_classifier(NX1), axis=1), g.predict(SNX1)

BBCLS0, WBCLS0 = BBY0 == label_x0, WBY0 > 0.5
BBCLS1, WBCLS1 = BBY1 == label_x0, WBY1 > 0.5

We don't need a threshold for the BB prediction either, since we extract the class with the highest probability and compare it with label_x0.

amparore commented 1 year ago

Ah sure, the argmax on a length-1 vector does nothing.
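Concretely, for a two-column binary output (predict_proba-style, columns summing to 1), the argmax agrees with thresholding column 1 at 0.5 (except exactly at a tie, where argmax picks the first class), so the unified code covers the binary case too. A toy check with invented probabilities:

```python
import numpy as np

# Binary predict_proba-style output: two columns summing to 1
proba = np.array([
    [0.9, 0.1],
    [0.4, 0.6],
    [0.5, 0.5],  # tie: argmax picks class 0, as does the strict > 0.5 test
])

by_argmax = np.argmax(proba, axis=1)
by_threshold = (proba[:, 1] > 0.5).astype(int)

print(by_argmax.tolist())     # [0, 1, 0]
print(by_threshold.tolist())  # [0, 1, 0]
```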

singsinghai commented 1 year ago

I also tested it, and it ran OK for both classification problems. Do you think there's anything else I should check about this?

amparore commented 1 year ago

Can you run the provided notebook without any problems, replicating the results?