juaml / julearn

Forschungszentrum Jülich Machine Learning Library
https://juaml.github.io/julearn
GNU Affero General Public License v3.0

Modify logic for final model training #273

Closed: fraimondo closed this 1 month ago

fraimondo commented 1 month ago

So far, when the user requested the final model, julearn first called scikit-learn's cross_validate and then fitted the model again on the full training data.
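A minimal sketch of that earlier two-step flow. It uses plain scikit-learn rather than julearn's internals, and the estimator, dataset and `KFold` settings are arbitrary choices for illustration:

```python
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_validate

X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000)
cv = KFold(n_splits=5, shuffle=True, random_state=42)

# Step 1: cross-validation; n_jobs=2 lets joblib run the folds in parallel.
scores = cross_validate(model, X, y, cv=cv, n_jobs=2)

# Step 2: only after cross_validate has returned does the main process
# fit the final model on the full data, while the workers sit idle.
final_model = clone(model).fit(X, y)
```

The refit in step 2 cannot start until every fold has finished, so it always adds its full duration on top of the cross-validation time.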

The main issue arises when using joblib to parallelize: there was one job per outer CV fold, and only once all of them had finished did the main process fit the final model. With enough resources this is suboptimal, as one might want to fit the final model at the same time as the individual folds.

This PR changes the internal logic so that the result is the same, but the fitting happens at a different time. The idea is to add an "extra" fold to the CV object which includes the whole dataset. After the call to cross_validate returns, we remove the last entry and use it as the final model. The output is identical, but joblib can now parallelize the CV folds and the final model fit together.
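A sketch of the extra-fold idea under stated assumptions: `ExtraFoldCV` is a hypothetical wrapper written for this example (it is not the class added in `julearn/model_selection/final_model_cv.py`), and a scikit-learn estimator on a toy dataset stands in for a julearn pipeline. `cross_validate` accepts any object exposing `split` and `get_n_splits` as its `cv` argument:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_validate


class ExtraFoldCV:
    """Hypothetical CV wrapper: yields the wrapped CV's splits, then one
    extra "fold" whose train (and test) set is the whole dataset."""

    def __init__(self, cv):
        self.cv = cv

    def split(self, X, y=None, groups=None):
        yield from self.cv.split(X, y, groups)
        all_idx = np.arange(len(X))
        yield all_idx, all_idx  # the final-model "fold"

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.cv.get_n_splits(X, y, groups) + 1


X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000)
cv = ExtraFoldCV(KFold(n_splits=5, shuffle=True, random_state=42))

# All 5 + 1 fits are handed to joblib together, so with n_jobs > 1 the
# full-data fit can run alongside the CV folds.
out = cross_validate(model, X, y, cv=cv, n_jobs=2, return_estimator=True)

# The last estimator was fitted on the full dataset: keep it as the final
# model and drop the last entry from every result array.
final_model = out["estimator"][-1]
out = {key: values[:-1] for key, values in out.items()}
```

Because the extra fold trains and tests on the same samples, its test score carries no information about generalization, which is why it is discarded together with the other last entries.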

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 95.34884% with 2 lines in your changes missing coverage. Please review.

Project coverage is 89.89%. Comparing base (eb7207f) to head (cfb4936). Report is 6 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| julearn/utils/_cv.py | 60.00% | 1 Missing and 1 partial |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #273      +/-   ##
==========================================
+ Coverage   89.83%   89.89%   +0.05%
==========================================
  Files          54       56       +2
  Lines        2449     2483      +34
  Branches      497      504       +7
==========================================
+ Hits         2200     2232      +32
- Misses        163      164       +1
- Partials       86       87       +1
```

| Flag | Coverage Δ |
|---|---|
| docs | `100.00% <ø> (ø)` |
| julearn | `89.88% <95.34%> (+0.05%)` |

Flags with carried forward coverage won't be shown.
| Files with missing lines | Coverage Δ |
|---|---|
| julearn/api.py | `92.89% <100.00%> (+0.26%)` |
| julearn/model_selection/final_model_cv.py | `100.00% <100.00%> (ø)` |
| julearn/model_selection/utils.py | `100.00% <100.00%> (ø)` |
| julearn/utils/_cv.py | `90.24% <60.00%> (-4.21%)` |
github-actions[bot] commented 1 month ago

PR Preview Action v1.4.8: Preview removed because the pull request was closed. 2024-09-26 13:51 UTC