emadeldeen24 opened this issue 4 years ago
Sorry for the late response. K-fold CV can be done by modifying `BaseDataLoader`.
The current implementation of `_split_sampler` uses the first `len_valid` samples of the dataset as the validation set. Modify this function to take an additional argument, `fold_idx`, and to use the specified part of the dataset as the validation set instead.
To be more specific, this part should be changed to the following:

```python
valid_idx = idx_full[len_valid*fold_idx, len_valid*(fold_idx+1)]
train_idx = np.delete(idx_full, np.arange(0, len_valid))
```
Then, add `fold_idx` to `BaseDataLoader.__init__` and `YourDataLoader.__init__`, and specify it in the config file under `data_loader.args`.
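The plumbing for that last step can be sketched with plain Python. This is a minimal sketch, not the template's code: the two classes below are simplified stand-ins that only store the split settings, the config excerpt is hypothetical (only `fold_idx` is the new key; `batch_size` and `validation_split` are placeholders), and it assumes the loader is built by unpacking `data_loader.args` as keyword arguments. It shows why both `__init__` signatures need the new parameter: if `YourDataLoader.__init__` does not accept `fold_idx`, the unpacking fails with a `TypeError`.

```python
import json

# Hypothetical config excerpt: only "fold_idx" is new; the other keys are placeholders.
CONFIG_TEXT = """
{
    "data_loader": {
        "type": "YourDataLoader",
        "args": {
            "batch_size": 64,
            "validation_split": 0.2,
            "fold_idx": 0
        }
    }
}
"""


class BaseDataLoader:
    """Simplified stand-in for BaseDataLoader: it only stores the split settings."""

    def __init__(self, n_samples, batch_size, validation_split, fold_idx=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.validation_split = validation_split
        # fold_idx=0 keeps the old behaviour: the first len_valid samples are validation.
        self.fold_idx = fold_idx


class YourDataLoader(BaseDataLoader):
    """Simplified stand-in for a project loader: it must accept fold_idx and forward it."""

    def __init__(self, batch_size, validation_split=0.0, fold_idx=0):
        n_samples = 1000  # placeholder for len(dataset)
        super().__init__(n_samples, batch_size, validation_split, fold_idx)


config = json.loads(CONFIG_TEXT)
# Every entry under data_loader.args arrives as a keyword argument.
loader = YourDataLoader(**config["data_loader"]["args"])
print(loader.fold_idx, loader.validation_split)
```

Giving `fold_idx` a default of 0 keeps existing configs that lack the key working unchanged.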
@SunQpark I think you mean:

```python
valid_idx = idx_full[len_valid * fold_idx : len_valid * (fold_idx + 1)]
train_idx = np.delete(idx_full, np.arange(len_valid * fold_idx, len_valid * (fold_idx + 1)))
```
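For reference, here is the corrected slicing wrapped in a standalone function so it can be checked for every fold. This is a sketch under stated assumptions: `kfold_split` is a hypothetical helper, `len_valid` is taken as `n_samples // n_folds` rather than from a fractional `validation_split`, and the index array is shuffled with a fixed seed via NumPy's `default_rng` (the template's own shuffling may differ). The fixed seed matters: every fold must slice the same permutation, otherwise the validation folds overlap.

```python
import numpy as np


def kfold_split(n_samples, n_folds, fold_idx, seed=0):
    """Return (train_idx, valid_idx) for one fold using the slicing from this thread."""
    if not 0 <= fold_idx < n_folds:
        raise ValueError("fold_idx must be in [0, n_folds)")
    idx_full = np.arange(n_samples)
    # Same seed for every fold -> same permutation -> disjoint validation folds.
    rng = np.random.default_rng(seed)
    rng.shuffle(idx_full)
    len_valid = n_samples // n_folds
    start, stop = len_valid * fold_idx, len_valid * (fold_idx + 1)
    valid_idx = idx_full[start:stop]
    # Drop the validation positions, keep everything else for training.
    train_idx = np.delete(idx_full, np.arange(start, stop))
    return train_idx, valid_idx


for k in range(5):
    train_idx, valid_idx = kfold_split(103, n_folds=5, fold_idx=k)
    print(k, len(train_idx), len(valid_idx))
```

One design consequence of this slicing: when `n_samples` is not divisible by `n_folds`, the last `n_samples % n_folds` shuffled samples never land in a validation fold and are always used for training. If every sample should be validated exactly once, `np.array_split(idx_full, n_folds)` gives folds whose sizes differ by at most one.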
I have implemented cross-validation in my Pytorch-Template; see #88.
Yeah, please. Or do you have any suggestions on how to modify the code if I want to use k-fold cross-validation?
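Once the loader takes `fold_idx`, one way to get full k-fold CV is an outer loop that runs training once per fold and collects the validation scores. A minimal sketch, assuming the config is a nested dict shaped like the config file and that `train_fold` is a hypothetical callback you supply, which builds a fresh model, data loader and trainer from the config it receives and returns a validation metric. Building a fresh model per fold matters; reusing weights across folds leaks validation data into training.

```python
import copy
from statistics import mean
from typing import Any, Callable, Dict, List


def cross_validate(base_config: Dict[str, Any], n_folds: int,
                   train_fold: Callable[[Dict[str, Any]], float]) -> List[float]:
    """Call train_fold once per fold and return the per-fold validation scores."""
    scores = []
    for fold_idx in range(n_folds):
        # Deep-copy so each run gets its own fold_idx and base_config stays untouched.
        config = copy.deepcopy(base_config)
        config["data_loader"]["args"]["fold_idx"] = fold_idx
        scores.append(train_fold(config))
    return scores


if __name__ == "__main__":
    def fake_train_fold(config):
        # Placeholder "training": the score just depends on the fold index.
        return 0.9 + 0.01 * config["data_loader"]["args"]["fold_idx"]

    base = {"data_loader": {"args": {"validation_split": 0.2, "fold_idx": 0}}}
    scores = cross_validate(base, n_folds=5, train_fold=fake_train_fold)
    print("per-fold:", scores)
    print("mean:", mean(scores))
```

Reporting the mean (and spread) of the per-fold scores, rather than the score of a single fold, is the usual way to summarize a k-fold run.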