amparore / leaf

A Python framework for the quantitative evaluation of eXplainable AI methods
16 stars, 8 forks

What does leaf.train() do? #2

Closed: singsinghai closed this issue 1 year ago

singsinghai commented 1 year ago

Hi Amparore, I wonder why you retrain the model with your own function. As far as I understand, the following code

if use_weights:
    nT = sum(labels_train)
    nF = len(labels_train) - nT
    weights_train = ((labels_train==False)*nT + (labels_train==True)*nF) / len(labels_train)

    model.fit(train, labels_train, sample_weight=weights_train)

means that you want to assign different weights to the two classes in this classifier. Why is this important? Is it OK to use the LEAF class with my original model instead?
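To see what the quoted weighting does, here is a minimal pure-Python sketch of the same formula. The helper names `leaf_style_weights` and `total_weight_per_class` are hypothetical and not part of LEAF. Each False sample gets nT/N and each True sample gets nF/N, so the minority class receives the larger per-sample weight and both classes end up with the same total weight, nT·nF/N.

```python
from collections import defaultdict


def leaf_style_weights(labels):
    """Per-sample weights following the formula quoted above (hypothetical helper).

    A False sample gets nT / N and a True sample gets nF / N, where nT and nF
    count the True and False labels and N = len(labels).
    """
    n = len(labels)
    n_true = sum(1 for y in labels if y)
    n_false = n - n_true
    return [(n_false if y else n_true) / n for y in labels]


def total_weight_per_class(labels, weights):
    """Sum the weights assigned to each class (hypothetical helper)."""
    totals = defaultdict(float)
    for y, w in zip(labels, weights):
        totals[bool(y)] += w
    return dict(totals)


# Unbalanced toy labels: 2 positives, 8 negatives.
labels = [True] * 2 + [False] * 8
weights = leaf_style_weights(labels)
print(weights[0], weights[-1])  # per-sample weight: minority vs majority
print(total_weight_per_class(labels, weights))
```

With 2 positives and 8 negatives, each positive gets 0.8 and each negative 0.2, so both classes sum to 1.6 and, in aggregate, contribute equally to a weighted training loss.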

amparore commented 1 year ago

I think it was just to work well with unbalanced datasets. Of course, you can rebalance the dataset beforehand without relying on sample weights, or train directly on the unbalanced data. You can definitely skip this part of the code if you don't need it.
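If you prefer to rebalance the dataset beforehand, one simple option is random oversampling: duplicate randomly chosen minority-class rows until both classes have the same size. The sketch below assumes a binary problem stored in plain Python lists; `oversample_minority` is a hypothetical helper, not part of LEAF.

```python
import random


def oversample_minority(samples, labels, seed=0):
    """Randomly duplicate minority-class rows until both classes are equally sized.

    Hypothetical helper for a binary problem; plain random oversampling,
    not part of LEAF.
    """
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y]
    neg = [i for i, y in enumerate(labels) if not y]
    minority, majority = (pos, neg) if len(pos) < len(neg) else (neg, pos)
    if not minority:
        raise ValueError("cannot oversample a class with no samples")
    # Draw extra minority indices with replacement to close the gap.
    extra = [rng.choice(minority) for _ in range(len(majority) - len(minority))]
    idx = pos + neg + extra
    rng.shuffle(idx)
    return [samples[i] for i in idx], [labels[i] for i in idx]


X = [[float(i)] for i in range(10)]
y = [True] * 2 + [False] * 8
Xb, yb = oversample_minority(X, y)
print(sum(yb), len(yb) - sum(yb))  # 8 8
```

The rebalanced lists can then be passed to model.fit() without a sample_weight argument. Since oversampling repeats rows verbatim, it is usually applied only to the training split, after the test data has been set aside.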

amparore commented 1 year ago

And, just to be clear, the train_model() function is only there for convenience. You can train your models however you like, since the LEAF class takes an arbitrary trained model as input (it does not need to be trained with leaf.train_model()).
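Since LEAF accepts any trained model, you can also reproduce the same class weighting in your own training code. For a binary problem, the formula quoted in this issue is proportional to scikit-learn's class_weight="balanced" heuristic, n_samples / (n_classes * n_c). The sketch below (pure Python; `class_weights` is a hypothetical helper) checks that the ratio between the two schemes is the same constant, 2·nT·nF/N², for both classes. Because the overall scale can interact with an estimator's regularization strength, the two may still yield slightly different fitted models.

```python
def class_weights(n_true, n_false):
    """Per-class weights for a binary problem under two schemes (hypothetical helper).

    leaf:     the formula quoted in this issue (True -> nF / N, False -> nT / N).
    balanced: N / (2 * n_c), scikit-learn's "balanced" heuristic with 2 classes.
    """
    if n_true <= 0 or n_false <= 0:
        raise ValueError("both classes need at least one sample")
    n = n_true + n_false
    leaf = {True: n_false / n, False: n_true / n}
    balanced = {True: n / (2 * n_true), False: n / (2 * n_false)}
    return leaf, balanced


leaf, balanced = class_weights(n_true=3, n_false=9)
# The ratio leaf / balanced equals 2 * nT * nF / N**2 for both classes
# (0.375 here, up to floating-point rounding), i.e. one global scale factor.
for c in (True, False):
    print(c, leaf[c] / balanced[c])
```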