afeinstein20 / stella

For characterizing flares with convolutional neural networks
https://adina.feinste.in/stella
MIT License

cross validation error: ValueError: Inconsistent data column lengths #25

Open · dani753 opened this issue 2 years ago

dani753 commented 2 years ago
def cross_validation(self, seed=2, epochs=350, batch_size=64,
                     n_splits=5, shuffle=False, pred_test=False, save=False):
    """

    """

    from sklearn.model_selection import KFold
    from sklearn.metrics import precision_recall_curve
    from sklearn.metrics import average_precision_score

    num_flares = len(self.labels)
    trainval_cutoff = int(0.90 * num_flares)

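    # Round trainval_cutoff down to a multiple of n_splits so that
    # KFold produces folds of equal size.
    remaining = trainval_cutoff % n_splits
    trainval_cutoff -= remaining

    # ... (rest of the method unchanged)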
dani753 commented 2 years ago

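To avoid this error, add the following two lines right after trainval_cutoff is defined:

    remaining = trainval_cutoff % n_splits
    trainval_cutoff -= remaining

This rounds trainval_cutoff down to the nearest multiple of n_splits, so that KFold produces folds of equal size.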
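A minimal, dependency-free sketch of why the trim helps. `kfold_fold_sizes` reproduces the fold-size rule documented for scikit-learn's `KFold` (the first `n_samples % n_splits` folds get one extra sample), and `trimmed_cutoff` repeats the patched arithmetic from the snippet above. Both function names and the dataset size of 1003 are hypothetical, chosen only for illustration. That the ValueError comes from combining per-fold arrays of unequal length into one table is an assumption about the rest of `cross_validation`, which this issue does not show.

```python
def kfold_fold_sizes(n_samples, n_splits):
    """Fold sizes following the rule documented for
    sklearn.model_selection.KFold: the first n_samples % n_splits
    folds get one extra sample, the rest get n_samples // n_splits."""
    base, extra = divmod(n_samples, n_splits)
    return [base + 1 if i < extra else base for i in range(n_splits)]


def trimmed_cutoff(num_flares, n_splits, frac=0.90):
    """Hypothetical helper mirroring the patched code: take 90% of the
    samples, then round down to a multiple of n_splits."""
    cutoff = int(frac * num_flares)
    cutoff -= cutoff % n_splits
    return cutoff


num_flares = 1003  # hypothetical dataset size
n_splits = 5

# Without the fix: 902 samples cannot be split evenly into 5 folds.
untrimmed = int(0.90 * num_flares)
print(untrimmed, kfold_fold_sizes(untrimmed, n_splits))
# 902 [181, 181, 180, 180, 180]

# With the fix: 900 samples give 5 folds of identical size.
trimmed = trimmed_cutoff(num_flares, n_splits)
print(trimmed, kfold_fold_sizes(trimmed, n_splits))
# 900 [180, 180, 180, 180, 180]
```

The trade-off is that up to n_splits - 1 samples are excluded from the cross-validation folds, which is negligible for any realistically sized training set.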