dbsxfz / AugDF

Code for the paper 'Improve Deep Forest with Learnable Layerwise Augmentation Policy Schedule', submitted to ICASSP 2024.
MIT License

Support scikit-learn compatibility #1

Open · arandomgoodguy opened this issue 9 months ago

arandomgoodguy commented 9 months ago

Hi, very interesting work!

Is there any chance AugDF could support the scikit-learn fit API, so that we can call something like AugDF.fit()?

That would make it easier to compare the performance of multiple tabular models, since they would all share one interface. Thanks!
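To illustrate why a shared interface helps, here is a minimal sketch, assuming only standard scikit-learn: once every model exposes fit/predict, a single loop can score them all with the same cross-validation splits and metric. RandomForestClassifier and ExtraTreesClassifier are stand-ins; an AugDF estimator with a scikit-learn-compatible fit() could be added to the `models` dict the same way, but no such class is assumed to exist here.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = load_breast_cancer(return_X_y=True)
# Identical folds for every model make the comparison fair.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Stand-in models; any estimator with the scikit-learn fit/predict API fits here.
models = {
    "random_forest": RandomForestClassifier(n_estimators=100, random_state=0),
    "extra_trees": ExtraTreesClassifier(n_estimators=100, random_state=0),
}

results = {}
for name, model in models.items():
    scores = cross_val_score(model, X, y, cv=cv, scoring="f1")
    results[name] = scores.mean()
    print(f"{name}: mean F1 = {scores.mean():.3f}")
```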

dbsxfz commented 8 months ago

I understand your concerns. We acknowledge that we haven't provided clear function interfaces or sufficient usage examples, which has hurt the code's readability and extensibility. We plan to address this in the coming days; please note that the current code is not the final version.

There are two reasons why we didn't adopt a fit-style interface initially. First, our model searches for augmentation parameters for each layer, so several fit calls might be needed, which departs from the original purpose of fit. Tonight, we will provide a fit interface that uses an already-searched policy_schedule.
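One way to reconcile a per-layer search with a single fit() call is to pass the already-searched schedule in as a constructor parameter, as the reply suggests. The sketch below is hypothetical and is not AugDF's implementation: the class name `ScheduledCascadeClassifier` is invented, and treating each `policy_schedule` entry as a Gaussian-noise scale is purely an illustrative assumption. Each layer is a random forest whose out-of-fold class probabilities are appended to the original features for the next layer, in the spirit of a cascade forest.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class ScheduledCascadeClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical sklearn-style cascade whose layerwise augmentation is fixed by
    an already-searched policy_schedule. Each entry is ASSUMED to be a Gaussian
    noise scale for that layer; this is an illustration, not AugDF's policy."""

    def __init__(self, policy_schedule=(0.0, 0.1), n_folds=3, random_state=42):
        self.policy_schedule = policy_schedule
        self.n_folds = n_folds
        self.random_state = random_state

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        self.classes_ = np.unique(y)
        rng = np.random.default_rng(self.random_state)
        cv = StratifiedKFold(self.n_folds, shuffle=True, random_state=self.random_state)
        self.layers_ = []
        features = X
        for noise_scale in self.policy_schedule:
            # Hypothetical augmentation: add Gaussian noise at the scheduled scale.
            X_aug = features + rng.normal(0.0, noise_scale, size=features.shape)
            layer = RandomForestClassifier(n_estimators=50, random_state=self.random_state)
            # Out-of-fold probabilities become the next layer's extra features.
            oof_proba = cross_val_predict(layer, X_aug, y, cv=cv, method="predict_proba")
            self.layers_.append(layer.fit(X_aug, y))
            features = np.hstack([X, oof_proba])
        return self

    def predict_proba(self, X):
        check_is_fitted(self, "layers_")
        X = check_array(X)
        features = X
        for layer in self.layers_:
            proba = layer.predict_proba(features)
            features = np.hstack([X, proba])
        return proba

    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]


# Usage: a single fit() call with a schedule that was searched beforehand.
X, y = load_iris(return_X_y=True)
clf = ScheduledCascadeClassifier(policy_schedule=(0.0, 0.05, 0.05)).fit(X, y)
print(clf.score(X, y))
```

Because the schedule is an ordinary constructor parameter, utilities such as `get_params`, `clone`, and `cross_val_score` treat it like any other hyperparameter, while the search itself can live in a separate step.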

Second, as described in our paper, Deep Forest (DF) is prone to overfitting, and for simplicity we wanted to monitor performance on the validation and test sets after each layer is trained. We understand this may raise concerns about test set leakage, and we plan to separate the training and testing phases in the future. Please feel free to scrutinize the current code; you will find no behavior that leaks the test set.

dbsxfz commented 8 months ago

We have updated the fit function so that it uses the searched policy. A fit function that performs the search will be added soon.

arandomgoodguy commented 8 months ago

Hi, thanks for your quick response and implementation.

I have two more questions after taking a deeper look at the code.

  1. Eval Metric

I noticed that the model can currently only be evaluated by accuracy.

However, other metrics may be more appropriate in practice. For instance, on a highly imbalanced dataset, you may not want to evaluate the model by accuracy alone.

Could there be a wrapper that supports any scikit-learn metric, such as the F1 score?
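A small worked example of the imbalance concern, using only standard scikit-learn metrics on made-up labels: a "model" that always predicts the majority class scores 95% accuracy on a 95:5 split, while balanced accuracy and F1 expose that it never finds a positive.

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# 95 negatives and 5 positives; predictions are always the majority class.
y_true = [0] * 95 + [1] * 5
y_pred = [0] * 100

acc = accuracy_score(y_true, y_pred)               # 0.95: looks excellent
bal_acc = balanced_accuracy_score(y_true, y_pred)  # 0.5: mean per-class recall, chance level
f1 = f1_score(y_true, y_pred, zero_division=0)     # 0.0: no positive is ever predicted
print(acc, bal_acc, f1)
```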

  2. Validation Split

I noticed that the train/validation split is done inside the model:

StratifiedKFold(self.kf_N, shuffle=True, random_state=random_state)  # random_state defaults to 42
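The maintainers explain later in this thread that the cross-validation predictions are reused as augmented features, as in vanilla Deep Forest. The sketch below shows that general out-of-fold pattern with scikit-learn; it is not AugDF's code. `kf_N` mirrors `self.kf_N` from the line above, and the random forest is a stand-in for one cascade layer.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict

X, y = load_iris(return_X_y=True)
kf_N = 5  # number of folds, mirroring self.kf_N above
skf = StratifiedKFold(kf_N, shuffle=True, random_state=42)

layer = RandomForestClassifier(n_estimators=100, random_state=0)
# Each sample's class probabilities come from a model that never saw that sample.
oof_proba = cross_val_predict(layer, X, y, cv=skf, method="predict_proba")

# Next-layer input: original features concatenated with out-of-fold probabilities.
X_next = np.hstack([X, oof_proba])
print(X.shape, oof_proba.shape, X_next.shape)  # (150, 4) (150, 3) (150, 7)
```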

I wonder whether the train/validation split could be done outside the model, since in some cases users have a fixed validation set that they split off themselves before passing the data to the model.

With some tabular models, you can do something like this:

model.fit(X_train, y_train, eval_set=(X_valid, y_valid))
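An `eval_set` argument like the one requested could be resolved at the start of fit(): use the caller's fixed validation set when one is given, otherwise fall back to an internal stratified hold-out. The helper below is a hypothetical sketch; `resolve_validation` and its `valid_size` parameter are invented for illustration and are not part of AugDF.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def resolve_validation(X, y, eval_set=None, valid_size=0.2, random_state=42):
    """Hypothetical helper: prefer a user-supplied (X_valid, y_valid) pair,
    otherwise carve out an internal stratified hold-out split."""
    if eval_set is not None:
        X_valid, y_valid = eval_set
        return X, y, X_valid, y_valid  # all of the training data is kept for fitting
    X_fit, X_valid, y_fit, y_valid = train_test_split(
        X, y, test_size=valid_size, stratify=y, random_state=random_state
    )
    return X_fit, y_fit, X_valid, y_valid


X = np.arange(200).reshape(100, 2)
y = np.array([0] * 50 + [1] * 50)

# No eval_set: an internal 80/20 stratified split.
X_fit, y_fit, X_valid, y_valid = resolve_validation(X, y)
print(X_fit.shape, X_valid.shape)  # (80, 2) (20, 2)

# With eval_set: the user's fixed validation set is used as-is.
X_fit2, y_fit2, X_valid2, y_valid2 = resolve_validation(X, y, eval_set=(X[:10], y[:10]))
```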

I know that Deep Forest itself doesn't support 1 or 2 either, so I'm just curious whether this is possible. Thank you.

OswaldHongyu commented 8 months ago

Hi @arandomgoodguy, thank you for your attention to our work.

Firstly, regarding the test set leakage mentioned by @dbsxfz, I want to emphasize that the test data is not involved in the training process at all. It is used solely to monitor performance fluctuations during the layer stacking phase and among individuals in the population. In our previous implementation, where test data was passed to the training function, one can verify that training is unaffected when the test data is replaced with zero matrices. Nonetheless, the current version fully separates the training and testing procedures. For monitoring purposes, some functions from the previous version may be incorporated into the upcoming release. Note that this repository is currently being reorganized, with numerous modifications and amendments in the pipeline.
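The zero-matrix check described above can be written as a generic sanity test: train twice, once with the real test data passed in for monitoring and once with zeros, and confirm that the resulting models predict identically. `train_with_monitoring` is a hypothetical stand-in for a training routine that receives test data only to log a score; it is not AugDF's function.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def train_with_monitoring(X_train, y_train, X_test, y_test):
    """Hypothetical routine: test data is only scored for logging, never trained on."""
    model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_train, y_train)
    monitored_score = model.score(X_test, y_test)  # logged, never fed back into training
    return model, monitored_score


X, y = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, stratify=y, random_state=0)

model_real, _ = train_with_monitoring(X_tr, y_tr, X_te, y_te)
model_zero, _ = train_with_monitoring(X_tr, y_tr, np.zeros_like(X_te), y_te)

# If the test data plays no role in training, both models behave identically.
same = np.array_equal(model_real.predict_proba(X_te), model_zero.predict_proba(X_te))
print("identical predictions:", same)
```

With a fixed `random_state`, any difference between the two runs would point to the test data influencing training.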

Regarding the eval metric, it is indeed reasonable to substitute other metrics for accuracy in specific settings. At present, our primary focus is overall performance on tabular classification tasks, for which accuracy serves as a straightforward initial metric. We may consider additional metrics once the main amendments are complete. Changing the evaluation metric is straightforward in the current implementation: you can adjust the 'compute_accuracy' function to accommodate other metrics, which will in turn affect both the search process and the final evaluation.
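As an illustration of swapping the metric, here is a hypothetical generalization of a compute_accuracy-style helper. The real signature of AugDF's `compute_accuracy` is not shown in this thread, so `compute_score`, its arguments, and the assumption that class labels are 0..K-1 (so the argmax column index equals the label) are all invented for this sketch.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def compute_score(y_true, proba, metric="accuracy"):
    """Hypothetical helper. `proba` is an (n_samples, n_classes) array and class
    labels are ASSUMED to be 0..K-1, so argmax column index == predicted label."""
    y_pred = np.argmax(proba, axis=1)
    if metric == "accuracy":
        return accuracy_score(y_true, y_pred)
    if metric == "f1_macro":
        return f1_score(y_true, y_pred, average="macro")
    if metric == "roc_auc":
        # Binary: score the positive-class column; multiclass: one-vs-rest.
        if proba.shape[1] == 2:
            return roc_auc_score(y_true, proba[:, 1])
        return roc_auc_score(y_true, proba, multi_class="ovr")
    raise ValueError(f"unsupported metric: {metric!r}")


y_true = np.array([0, 0, 1, 1])
proba = np.array([[0.9, 0.1], [0.4, 0.6], [0.35, 0.65], [0.2, 0.8]])
print(compute_score(y_true, proba, "accuracy"))  # 0.75
print(compute_score(y_true, proba, "f1_macro"))
print(compute_score(y_true, proba, "roc_auc"))   # 1.0
```

Since the same helper would drive both the policy search and the final evaluation, the metric choice changes which schedule is selected, not just the reported number.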

Regarding the Validation Split, the use of cross-validation in AugDF serves dual purposes. First, the validation probabilities are employed as augmented features, akin to the approach in vanilla Deep Forest. Second, cross-validation scores aid in the searching procedure. Should an additional validation set be available, it's entirely feasible to learn the policy schedule on this true validation set, while reserving the inherent cross-validation solely for stacking layers. We may integrate this feature after our ongoing modifications are completed.
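The split of responsibilities described above (an external validation set for choosing the policy, internal cross-validation only for building stacked features) might look roughly like this for a single layer. Everything here is a hypothetical sketch: the candidate "policies" are assumed to be relative Gaussian-noise scales, and the hold-out split merely simulates a validation set the user would supply.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict, train_test_split

X, y = load_breast_cancer(return_X_y=True)
# Simulates a user-supplied, fixed validation set.
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25, stratify=y, random_state=0)

cv = StratifiedKFold(5, shuffle=True, random_state=42)
rng = np.random.default_rng(0)
candidate_policies = [0.0, 0.01, 0.05]  # hypothetical: noise scale relative to feature std


def evaluate_policy(noise_scale):
    X_aug = X_tr + rng.normal(0.0, noise_scale, size=X_tr.shape) * X_tr.std(axis=0)
    layer = RandomForestClassifier(n_estimators=100, random_state=0)
    # Internal CV is used only to build out-of-fold features for stacking.
    oof_proba = cross_val_predict(layer, X_aug, y_tr, cv=cv, method="predict_proba")
    layer.fit(X_aug, y_tr)
    # The policy itself is judged on the external validation set.
    return layer.score(X_val, y_val), oof_proba


results = {p: evaluate_policy(p) for p in candidate_policies}
best_policy = max(results, key=lambda p: results[p][0])
# Stack the chosen policy's out-of-fold probabilities for the next layer.
X_next = np.hstack([X_tr, results[best_policy][1]])
print("best policy:", best_policy, "next-layer input:", X_next.shape)
```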

Thank you again for your insightful observations.