meeg-ml-benchmarks / brain-age-benchmark-paper

M/EEG brain age benchmark paper
https://meeg-ml-benchmarks.github.io/brain-age-benchmark-paper/
BSD 3-Clause "New" or "Revised" License

added a fake scorer to save skorch model training history #23

Closed · gemeinl closed this 2 years ago

hubertjb commented 2 years ago

@gemeinl could we use something like this? https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.TrainEndCheckpoint
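For reference, a minimal sketch of how that callback would be wired into a skorch net (the module here is just a placeholder, not the actual benchmark model):

```python
import torch
from skorch import NeuralNetRegressor
from skorch.callbacks import TrainEndCheckpoint

# Placeholder module; the real benchmark uses braindecode models.
module = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(10, 1))

net = NeuralNetRegressor(
    module,
    max_epochs=10,
    callbacks=[
        # Writes parameters, optimizer state and the training history to
        # `dirname` once fitting ends. Note the fixed `dirname`: calling
        # .fit() once per CV fold reuses the same files.
        TrainEndCheckpoint(dirname='checkpoints', fn_prefix='train_end_'),
    ],
)
```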

Also, I guess it would be pretty complicated to track the learning curve on the test set as well?

gemeinl commented 2 years ago

I fear the different folds will overwrite each other if we use this. We do not have access to the fold id within scikit-learn's cross_validate to set as dirname, do we? I also don't see a way of changing the file opening mode to something like append. Do you see a way to make it work?

If we stick to scikit-learn's cross_validate, I don't think we can give the model a test/validation set. Scikit-learn creates indices for train and test, creates subsets of the data, calls estimator.fit(train_X, train_y), and then calls estimator.predict(valid_X). At training time, the estimator simply does not know about the test/validation set.
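For context, a simplified, runnable sketch of what cross_validate does per fold (toy data and estimator, not the benchmark setup):

```python
import numpy as np
from sklearn.base import clone
from sklearn.dummy import DummyRegressor
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold

# Toy data, estimator and splitter, just to make the sketch runnable.
X, y = np.random.randn(100, 5), np.random.randn(100)
estimator, cv = DummyRegressor(), KFold(n_splits=5)
scorer = check_scoring(estimator, scoring='neg_mean_absolute_error')

# Per fold, cross_validate clones the estimator, fits it on the train
# indices only, and scores on the held-out indices. The estimator (and
# therefore any skorch callback) never sees the test fold during fit().
for train_idx, test_idx in cv.split(X):
    est = clone(estimator).fit(X[train_idx], y[train_idx])
    score = scorer(est, X[test_idx], y[test_idx])
```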

hubertjb commented 2 years ago

Aah that's a good point @gemeinl. I wonder what actually happens with the Checkpoint callbacks when you call .fit() more than once. It might indeed be that it will overwrite previous outputs, but maybe there could be a way to avoid that. Anyway, I think your approach is a lot simpler, so I vote for that!

About using a validation set: this might be a bit convoluted, but we could override the .fit() method of the skorch estimator to first split X_train and y_train with train_test_split, and then update the estimator with the validation set before calling the actual fit()? This would open the door to doing early stopping as well, which might reduce training time/avoid overfitting on datasets like Cam-CAN.
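A minimal sketch of that idea, with made-up class and parameter names (not code from this repo):

```python
from sklearn.model_selection import train_test_split
from skorch import NeuralNetRegressor
from skorch.dataset import Dataset
from skorch.helper import predefined_split

class ValidSplitRegressor(NeuralNetRegressor):
    """Hypothetical sketch: carve a validation set out of the training
    data before fitting, so callbacks such as EarlyStopping can monitor
    valid_loss."""

    def fit(self, X, y, **fit_params):
        X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2)
        # Make skorch use exactly this validation set instead of its
        # default internal split.
        self.train_split = predefined_split(Dataset(X_val, y_val))
        return super().fit(X_tr, y_tr, **fit_params)
```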

gemeinl commented 2 years ago

> Aah that's a good point @gemeinl. I wonder what actually happens with the Checkpoint callbacks when you call .fit() more than once. It might indeed be that it will overwrite previous outputs, but maybe there could be a way to avoid that. Anyway, I think your approach is a lot simpler, so I vote for that!

I think 'my' approach also only works when we use n_jobs=1 in cross_validate. But since we always do this for the braindecode models we should be good.
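For illustration, a simplified sketch of the 'fake scorer' idea from this PR's title (not the actual PR code):

```python
from sklearn.model_selection import cross_validate

# The histories list lives in this process's memory, which is why the
# trick requires n_jobs=1 (parallel workers would fill their own copies).
histories = []

def history_scorer(estimator, X, y):
    """Record the skorch training history as a side effect of scoring."""
    histories.append(list(estimator.history))  # one entry per CV fold
    return 0.0  # dummy value; the real metrics come from other scorers

# Usage, assuming `net` is the skorch estimator and X, y the benchmark data:
# cross_validate(net, X, y, n_jobs=1,
#                scoring={'mae': 'neg_mean_absolute_error',
#                         'history': history_scorer})
```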

> About using a validation set: this might be a bit convoluted, but we could override the .fit() method of the skorch estimator to first split X_train and y_train with train_test_split, and then update the estimator with the validation set before calling the actual fit()? This would open the door to doing early stopping as well, which might reduce training time/avoid overfitting on datasets like Cam-CAN.

Yes, that could work. But wouldn't the model then see less training data than the other estimators in the benchmark (due to the additional train/test split)? Performance might not be comparable then. Apart from writing our own cross_validate function, I don't see a way around this right now.

gemeinl commented 2 years ago

Ah, you mean in a hyperparameter tuning run prior to the actual benchmark. Hm, yes. It might still be tricky with the callbacks, though: it's unclear whether the callbacks that compute the loss on the validation set would then be added correctly.

hubertjb commented 2 years ago

> Ah, you mean in a hyperparameter tuning run prior to the actual benchmark. Hm, yes. It might still be tricky with the callbacks, though: it's unclear whether the callbacks that compute the loss on the validation set would then be added correctly.

Yes, exactly. I haven't finished the hyperparameter tuning part, but the validation set part ended up working (with the callbacks) by using the train_split argument of NeuralNetRegressor with a new object similar to the BraindecodeKFold class you wrote. Also, the whole thing seems to be correctly logged with your logger in #24!
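For illustration, a hypothetical train_split object in that spirit; the actual BraindecodeKFold-style implementation may differ:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

class TrainValidSplit:
    """Hypothetical object (not the actual repo code). Passed as
    `train_split=` to NeuralNetRegressor, it hands skorch an explicit
    validation subset so valid_loss shows up in the training history."""

    def __init__(self, valid_size=0.2, random_state=None):
        self.valid_size = valid_size
        self.random_state = random_state

    def __call__(self, dataset, y=None, groups=None):
        # Split indices, then wrap both halves as datasets for skorch.
        idx_train, idx_valid = train_test_split(
            np.arange(len(dataset)), test_size=self.valid_size,
            random_state=self.random_state)
        return Subset(dataset, idx_train), Subset(dataset, idx_valid)

# e.g. NeuralNetRegressor(module, train_split=TrainValidSplit(valid_size=0.2))
```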

gemeinl commented 2 years ago

Continuing with this in https://github.com/dengemann/meeg-brain-age-benchmark-paper/pull/24