LAMDA-NJU / Deep-Forest

An Efficient, Scalable and Optimized Python Framework for Deep Forest (2021.2.1)
https://deep-forest.readthedocs.io

Layer with the key layer_0 already exists in the internal container #102

Closed: simonprovost closed this issue 2 years ago

simonprovost commented 2 years ago

Hi guys,

I would like to evaluate my deep forest configuration (i.e., with my custom base estimators) using k-fold cross-validation, with the following code:

from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, roc_auc_score

# X is a pandas DataFrame, y a pandas Series, and model a CascadeForestClassifier
k = 5
kf = KFold(n_splits=k, random_state=None)

acc_score = []
auroc_score = []

for train_index, test_index in kf.split(X):
    # Positional indexing so the folds also work with a non-default index
    X_train, X_test = X.iloc[train_index, :], X.iloc[test_index, :]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]

    # The same model instance is fitted again on every fold
    model.fit(X_train.values, y_train.values)
    pred_values = model.predict(X_test.values)

    # Probability of the positive class, used for AUROC
    predict_prob = model.predict_proba(X_test.values)[:, 1]
    auroc = roc_auc_score(y_test, predict_prob)
    acc = accuracy_score(y_test, pred_values)
    auroc_score.append(auroc)
    acc_score.append(acc)

avg_acc_score = sum(acc_score) / k

print('accuracy of each fold - {}'.format(acc_score))
print('Avg accuracy : {}'.format(avg_acc_score))
print('AUROC of each fold - {}'.format(auroc_score))
print('Avg AUROC : {}'.format(sum(auroc_score) / k))

Note: `model` here is an instance of CascadeForestClassifier, declared before the loop.

However, on the second iteration of the loop, I received the following error: 'RuntimeError: Layer with the key layer_0 already exists in the internal container.' I imagine this is because the classifier is not re-instantiated, so state stored on the object during the first fit conflicts with the second fit. My current workaround is to re-instantiate the deep forest classifier at the beginning of each iteration, but I am not sure this is the right approach. Do you have another suggestion?

The following is the traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/var/folders/xd/hlwb4c2s2qd9rn1b3bgs_1g80000gn/T/ipykernel_6711/3506965433.py in <module>
     23     y_train , y_test = y[train_index] , y[test_index]
     24 
---> 25     model.fit(X_train.values,y_train.values)
     26     pred_values = model.predict(X_test.values)
     27 

~/.local/lib/python3.8/site-packages/deepforest/cascade.py in fit(self, X, y, sample_weight)
   1395         y = self._encode_class_labels(y)
   1396 
-> 1397         super().fit(X, y, sample_weight)
   1398 
   1399     def predict_proba(self, X):

~/.local/lib/python3.8/site-packages/deepforest/cascade.py in fit(self, X, y, sample_weight)
    810 
    811         # Add the first cascade layer, binner
--> 812         self._set_layer(0, layer_)
    813         self._set_binner(0, binner_)
    814         self.n_layers_ += 1

~/.local/lib/python3.8/site-packages/deepforest/cascade.py in _set_layer(self, layer_idx, layer)
    572                 " container."
    573             )
--> 574             raise RuntimeError(msg.format(layer_key))
    575 
    576         self.layers_.update({layer_key: layer})

Any help would be greatly appreciated. Thank you, and best wishes!
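The last frame of the traceback shows the check that fails. The sketch below is a heavily simplified, hypothetical reconstruction (`ToyCascade` is not part of deepforest). It assumes, as the error implies, that the layer dictionary `layers_` survives between `fit` calls instead of being cleared, and that keys follow the `layer_<index>` pattern seen in the issue title.

```python
class ToyCascade:
    """Hypothetical, simplified stand-in; NOT the deepforest implementation."""

    def __init__(self) -> None:
        # Assumption: created once here and not reset at the start of fit()
        self.layers_ = {}
        self.n_layers_ = 0

    def _set_layer(self, layer_idx: int, layer: object) -> None:
        layer_key = "layer_{}".format(layer_idx)
        if layer_key in self.layers_:
            msg = "Layer with the key {} already exists in the internal container."
            raise RuntimeError(msg.format(layer_key))
        self.layers_.update({layer_key: layer})

    def fit(self, X, y) -> "ToyCascade":
        layer_ = object()  # placeholder for a trained cascade layer
        # Add the first cascade layer
        self._set_layer(0, layer_)
        self.n_layers_ += 1
        return self


model = ToyCascade()
model.fit(None, None)  # first fold: succeeds

message = ""
try:
    model.fit(None, None)  # second fold: layer_0 is still stored
except RuntimeError as err:
    message = str(err)

print(message)  # Layer with the key layer_0 already exists in the internal container.
```

Because the second `fit` finds `layer_0` already stored, it raises before adding anything, which matches the failure on the second fold.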

xuyxu commented 2 years ago

Thanks for reporting @simonprovost, will take a look at this in a few days ;-)

xuyxu commented 2 years ago

Sorry for the late response @simonprovost. To use cross-validation with deep forest, I think you should declare your deep forest model inside the loop. Calling the fit method several times on the same model instance amounts to refitting it, which is not supported yet.
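A minimal sketch of the suggested pattern, under stated assumptions: a factory function builds a brand-new model for every fold, so each instance is fitted exactly once. To keep it self-contained, it uses plain Python lists, a hypothetical `OneShotModel` that (like the behavior in this issue) refuses a second `fit`, and a hypothetical `kfold_indices` helper that mimics unshuffled `KFold` splits. In the real setup, `make_model` would construct and configure a new `CascadeForestClassifier`, including the custom base estimators, each time it is called.

```python
from typing import Any, Callable, List, Sequence, Tuple


class OneShotModel:
    """Hypothetical stand-in that raises if fit() is called twice."""

    def __init__(self) -> None:
        self.fitted = False
        self.majority: Any = None

    def fit(self, X: Sequence, y: Sequence) -> "OneShotModel":
        if self.fitted:
            raise RuntimeError("refitting is not supported")
        # Trivial "training": remember the most frequent label
        labels = list(y)
        self.majority = max(set(labels), key=labels.count)
        self.fitted = True
        return self

    def predict(self, X: Sequence) -> List[Any]:
        return [self.majority for _ in X]


def kfold_indices(n: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Contiguous, unshuffled folds; the first n % k folds get one extra sample."""
    folds, start = [], 0
    for i in range(k):
        size = n // k + (1 if i < n % k else 0)
        test = list(range(start, start + size))
        train = [j for j in range(n) if j < start or j >= start + size]
        folds.append((train, test))
        start += size
    return folds


def cross_val_accuracy(make_model: Callable[[], Any], X: Sequence,
                       y: Sequence, k: int = 5) -> List[float]:
    scores = []
    for train_idx, test_idx in kfold_indices(len(X), k):
        model = make_model()  # a new, unfitted model for every fold
        model.fit([X[j] for j in train_idx], [y[j] for j in train_idx])
        preds = model.predict([X[j] for j in test_idx])
        correct = sum(p == y[j] for p, j in zip(preds, test_idx))
        scores.append(correct / len(test_idx))
    return scores


X = [[i] for i in range(10)]
y = [0] * 7 + [1] * 3
scores = cross_val_accuracy(OneShotModel, X, y, k=5)
print(scores)  # [1.0, 1.0, 1.0, 0.5, 0.0]
print(sum(scores) / len(scores))

# Reusing a single instance across folds reproduces the failure mode
shared = OneShotModel()
try:
    cross_val_accuracy(lambda: shared, X, y, k=5)
    reuse_failed = False
except RuntimeError:
    reuse_failed = True
```

Passing the class (or any zero-argument callable) rather than an already-built instance is what guarantees a fresh model per fold; the `lambda: shared` call shows that handing over one instance fails on the second fold.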

xuyxu commented 2 years ago

Closed due to inactivity

simonprovost commented 2 years ago

@xuyxu No worries, and apologies for my own delayed response. We were able to run cross-validation ourselves by handling it in our own loop and averaging the results, and it worked well. Hopefully deep forest will support this in the near future; otherwise, I will submit a pull request when feasible, this year or next.

Cheers!

xuyxu commented 2 years ago

Let me know if you need any help, thanks ;-)