victoresque / pytorch-template

PyTorch deep learning projects made easy.
MIT License

any support of k-fold cross validation? #65

Open emadeldeen24 opened 4 years ago

LumosWC commented 4 years ago

Yes, please. Or any suggestions on how to modify the code to use k-fold cross validation?

SunQpark commented 4 years ago

Sorry for the late response. K-fold CV can be done by modifying BaseDataLoader.

The current implementation of _split_sampler uses the first n samples of the dataset as the validation set. Modify this function to take an additional argument, fold_idx, and to use the specified part of the dataset. More specifically, this part should be changed to the following.

valid_idx = idx_full[len_valid*fold_idx, len_valid*(fold_idx+1)]
train_idx = np.delete(idx_full, np.arange(0, len_valid))

Then add fold_idx to BaseDataLoader.__init__ and YourDataLoader.__init__, and specify it in the config file under data_loader.args.

hsinlichu commented 4 years ago

@SunQpark I think you mean

valid_idx = idx_full[len_valid * fold_idx : len_valid * (fold_idx + 1)]
train_idx = np.delete(idx_full, np.arange(len_valid * fold_idx, len_valid * (fold_idx + 1)))
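
For anyone who wants to try the corrected slicing outside the template, here is a minimal standalone sketch. The helper name split_fold, its seed parameter, and the bounds checks are assumptions made for illustration and are not part of pytorch-template; only the two slicing lines come from the snippet above.

```python
import numpy as np


def split_fold(n_samples, len_valid, fold_idx, seed=0):
    """Return (train_idx, valid_idx) for one fold (hypothetical helper)."""
    if len_valid <= 0:
        raise ValueError("len_valid must be positive")
    n_folds = n_samples // len_valid
    if not 0 <= fold_idx < n_folds:
        raise ValueError(f"fold_idx must be in [0, {n_folds}), got {fold_idx}")

    # Shuffle once with a fixed seed so every fold sees the same permutation.
    idx_full = np.random.RandomState(seed).permutation(n_samples)

    start, stop = len_valid * fold_idx, len_valid * (fold_idx + 1)
    # The fold_idx-th block of the shuffled indices is the validation set.
    valid_idx = idx_full[start:stop]
    # np.delete removes by position, so everything outside the block is kept.
    train_idx = np.delete(idx_full, np.arange(start, stop))
    return train_idx, valid_idx


if __name__ == "__main__":
    for fold in range(5):
        train_idx, valid_idx = split_fold(n_samples=10, len_valid=2, fold_idx=fold)
        print(fold, sorted(valid_idx.tolist()), len(train_idx))
```

Note that when n_samples is not a multiple of len_valid, the leftover samples at the end of the permutation never land in a validation fold; they stay in the training set for every fold.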

deeperlearner commented 3 years ago

I have implemented cross validation in my Pytorch-Template. See #88.