ucl-pond / pySuStaIn

Subtype and Stage Inference (SuStaIn) algorithm with an example using simulated data.
MIT License

IndexError during CV #59

Open katrinaCode opened 2 months ago

katrinaCode commented 2 months ago

Hi all,

I wanted to submit a fix for an occasional error I get during cross-validation (CV). The error is as follows:

Traceback (most recent call last):
  File "", line 176, in <module>
    CVIC, loglike_matrix     = sustain_input.cross_validate_sustain_model(test_idxs)
  File "AbstractSustain.py", line 294, in cross_validate_sustain_model
    sustainData_test                = self.__sustainData.reindex(indx_test)
  File "ZscoreSustain.py", line 54, in reindex
    return ZScoreSustainData(self.data[index,], self.__numStages)
IndexError: arrays used as indices must be of integer (or boolean) type

My fix is simply to cast indx_test explicitly to an array of integers at line 277 of AbstractSustain.py:

indx_test                       = (test_idxs[fold]).astype(int)
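
A minimal sketch of why the cast helps, using a toy array as a stand-in for the SuStaIn data matrix (the names `data` and `idx_obj` are hypothetical and not part of pySuStaIn). NumPy refuses to index with an array whose dtype is `object`, even when every element is an integer, and `.astype(int)` converts it into a valid index:

```python
import numpy as np

# Toy stand-in for the data matrix held by the SuStaIn data object (assumption).
data = np.arange(20).reshape(10, 2)

# Integer values, but stored with object dtype, as can happen to a test fold.
idx_obj = np.array([1, 3, 5], dtype=object)

try:
    data[idx_obj,]          # same indexing pattern as in reindex()
    raised = False
except IndexError:
    raised = True           # object-dtype index arrays are rejected

# Casting to int yields a usable index array.
rows = data[idx_obj.astype(int),]
```

The cast is harmless when the fold is already an integer array, so applying it unconditionally should not change behavior in the cases that currently work.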

I'd be interested to hear any theories as to why this error occurs only intermittently: it happens with some models but not others, all run from identical versions of my notebook. In that section of the notebook, I follow the SuStaIn workshop essentially verbatim:

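import numpy as np
import sklearn.model_selection

# sustain_data, label_column, N_folds and sustain_input are defined earlier in the notebook
labels = sustain_data[label_column].values
cv = sklearn.model_selection.StratifiedKFold(n_splits=N_folds, shuffle=True, random_state=3)
cv_it = cv.split(sustain_data, labels)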

# SuStaIn currently accepts ragged arrays, which will raise problems in the future.
# We'll have to update this in the future, but this will have to do for now
test_idxs = []
for train, test in cv_it:
    test_idxs.append(test)
test_idxs = np.array(test_idxs,dtype='object')

for i, (train_index, test_index) in enumerate(cv.split(sustain_data, labels)):
    print(f"Fold {i}:")
    print(f"  Train: index={train_index}")
    print(f"  Test:  index={test_index}")

# perform cross-validation and output the cross-validation information criterion and
# log-likelihood on the test set for each subtypes model and fold combination
CVIC, loglike_matrix     = sustain_input.cross_validate_sustain_model(test_idxs)
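
One possible explanation for the irregularity, offered as a sketch rather than a confirmed diagnosis: `np.array(test_idxs, dtype='object')` behaves differently depending on whether the folds all have the same length. With equal-length folds NumPy builds a 2-D object array, so each `test_idxs[fold]` is itself an object-dtype array and fails as an index; with ragged folds it builds a 1-D array whose elements keep their integer dtype. That would make the error depend on how the sample count splits across folds. The helpers `stack_folds` and `stack_folds_1d` below are hypothetical names used only for illustration:

```python
import numpy as np

def stack_folds(folds):
    # Same construction as in the snippet above.
    return np.array(folds, dtype='object')

def stack_folds_1d(folds):
    # Alternative (assumption, not from the workshop): always a 1-D object
    # array, so each element keeps its original integer dtype.
    out = np.empty(len(folds), dtype=object)
    for i, fold in enumerate(folds):
        out[i] = fold
    return out

# Equal-length folds collapse into a 2-D object array.
equal = stack_folds([np.array([0, 1]), np.array([2, 3])])

# Ragged folds stay a 1-D array of integer arrays.
ragged = stack_folds([np.array([0, 1, 2]), np.array([3, 4])])

# The alternative construction stays 1-D even for equal-length folds.
safe = stack_folds_1d([np.array([0, 1]), np.array([2, 3])])

print(equal.shape, equal[0].dtype)
print(ragged.shape, ragged[0].dtype)
print(safe.shape, safe[0].dtype)
```

If this is indeed the cause, either the `.astype(int)` cast inside AbstractSustain.py or a construction along the lines of `stack_folds_1d` in the notebook would avoid the object-dtype index.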

Thanks 😊

xullllllll commented 1 month ago

I've never seen an error like that. Can I ask you a question?