ucl-pond / pySuStaIn

Subtype and Stage Inference (SuStaIn) algorithm with an example using simulated data.
MIT License

bug in cross-validation when using "select_fold" #11

Closed: illdopejake closed this issue 3 years ago

illdopejake commented 4 years ago

Hello!

Parallelizing the cross-validation requires use of the "select_fold" argument. One might, for example, launch 10 instances of cross-validation, one for each of (say) 10 folds, in which case the given fold (say 3) would be passed to the select_fold argument. However, I've run into a few issues with this function.

First, consider line 218, if select_fold:. Many users will pass 0 to get the first fold, but 0 is falsy in Python, so it does not satisfy the condition if select_fold. There are many ways to fix this, but since the default value of select_fold is [], my janky fix was just: if select_fold != []:
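The root cause is Python truthiness: 0 and an empty list are both falsy, so if select_fold: cannot tell "fold 0 selected" apart from "no fold selected". A minimal sketch of the two checks side by side; both helper names are hypothetical and not part of pySuStaIn:

```python
def fold_was_selected_buggy(select_fold):
    # Original check: relies on truthiness, so 0 and [] both look "unset".
    return bool(select_fold)


def fold_was_selected(select_fold):
    # Jake's workaround: compare against the default empty list explicitly.
    return select_fold != []


for value in ([], 0, 3):
    print(repr(value), fold_was_selected_buggy(value), fold_was_selected(value))
```

With a None default instead of [], the conventional Python idiom would be select_fold is not None, which avoids comparing against a list at all.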

The next issue is in the following lines, 218-220:

if select_fold:
    test_idxs  = test_idxs[select_fold]
Nfolds  = len(test_idxs)

I'm not sure what the intention was here, but the result is that Nfolds actually becomes the number of subjects in the test set. So, instead of having the desired 1 fold, you end up with N folds, where N is the number of subjects in the test set.
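To see why, take a toy test_idxs with three folds over 10 subjects (the indices are made up for illustration). Indexing with select_fold returns the contents of one fold, so len() counts the subjects in that fold rather than the number of folds:

```python
# Toy cross-validation layout: 3 folds over 10 subjects (illustrative only).
test_idxs = [[0, 3, 6], [1, 4, 7, 9], [2, 5, 8]]

select_fold = 1
if select_fold:
    # This replaces the list of folds with the subject indices of fold 1...
    test_idxs = test_idxs[select_fold]
# ...so Nfolds now counts subjects in that fold, not folds.
Nfolds = len(test_idxs)

print(test_idxs)  # [1, 4, 7, 9]
print(Nfolds)     # 4, even though only 3 folds exist and 1 was requested
```

A subsequent for fold in range(Nfolds) loop would therefore run 4 times instead of once.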

My solution here requires a few changes. First, lines 218-220 are changed to this:

if select_fold != []:
    Nfolds = 1
else:
    Nfolds = len(test_idxs)

Then, to disrupt the code as little as possible, I added the following lines below line 226 (line 226 itself is included for reference):

for fold in range(Nfolds): 
    if select_fold != []:  # or whatever you change line 218 to
        fold = select_fold
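Putting Jake's changes together, the patched loop can be sketched as a standalone function. run_folds and num_samples are hypothetical stand-ins for the class method and self.__sustainData.getNumSamples(); the real method also fits SuStaIn on each split, which is omitted here:

```python
def run_folds(test_idxs, num_samples, select_fold=[]):
    """Return (fold, train, test) tuples following Jake's patch (integer select_fold)."""
    # The [] default mirrors the original signature; it is never mutated here.
    if select_fold != []:
        Nfolds = 1
    else:
        Nfolds = len(test_idxs)

    splits = []
    for fold in range(Nfolds):
        if select_fold != []:
            fold = select_fold  # override the loop counter with the chosen fold
        indx_test = test_idxs[fold]
        indx_train = [x for x in range(num_samples) if x not in indx_test]
        splits.append((fold, indx_train, indx_test))
    return splits


test_idxs = [[0, 3, 6], [1, 4, 7, 9], [2, 5, 8]]
print(run_folds(test_idxs, 10, select_fold=0))
# [(0, [1, 2, 4, 5, 7, 8, 9], [0, 3, 6])]
```

Each parallel job would call this with its own integer fold; calling it without select_fold runs all folds in order.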

With these three small changes, the script worked without issue for me, though there may be more elegant solutions. Thanks for bringing SuStaIn to Python!!

LeonAksman commented 3 years ago

Hi Jake,

Thanks a lot for raising this issue. I finally got around to it because I'm planning a fairly large update to the master branch. It won't change the algorithm, but it will make the code installable as a Python package, along with a bunch of changes in simrun.py to better showcase pySuStaIn's features.

I agree that the way the code was written was plain wrong. Looking at your proposed changes, I realized they won't work if the user passes an array of folds, so maybe we can use this logic instead:

        if select_fold != []:
            Nfolds                          = len(select_fold)
        else:
            select_fold                     = test_idxs
            Nfolds                          = len(test_idxs)

This sets the number of folds to the number the user passed in, or uses all folds if the user didn't pass anything. It also makes sure that select_fold holds the folds to be run.

Then I added this:

        for fold in range(Nfolds):

            indx_train                      = np.array([x for x in range(self.__sustainData.getNumSamples()) if x not in select_fold[fold]])
            indx_test                       = select_fold[fold]

Here I replaced test_idxs[fold] with select_fold[fold] in both lines.
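A standalone sketch of Leon's version, with the hypothetical run_folds_leon and num_samples again standing in for the class method. Because select_fold[fold] is used directly as the test set, every entry of select_fold must be a list of subject indices, exactly like test_idxs, rather than a fold number:

```python
def run_folds_leon(test_idxs, num_samples, select_fold=[]):
    """Return (train, test) pairs following Leon's proposed logic."""
    if select_fold != []:
        Nfolds = len(select_fold)
    else:
        select_fold = test_idxs  # default: run every fold
        Nfolds = len(test_idxs)

    splits = []
    for fold in range(Nfolds):
        # select_fold[fold] is used as the test set itself, so each entry
        # must be a list of subject indices, not a fold number.
        indx_train = [x for x in range(num_samples) if x not in select_fold[fold]]
        indx_test = select_fold[fold]
        splits.append((indx_train, indx_test))
    return splits


test_idxs = [[0, 3, 6], [1, 4, 7, 9], [2, 5, 8]]
print(len(run_folds_leon(test_idxs, 10)))             # 3: all folds
print(run_folds_leon(test_idxs, 10, [test_idxs[2]]))  # one fold, passed as its index list
```

Passing fold numbers such as [2] would instead fail, because x not in 2 raises a TypeError (an int is not iterable).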

What do you think of this?

illdopejake commented 3 years ago

Hi Leon,

Thanks for looking into this, and sorry I didn't respond earlier. I just got around to having another look, and I'm still running into a similar issue. I think the source of the issue is ultimately a lack of documentation for this function.

So, as I learned from the tutorial notebook, test_idxs is supposed to be a nested list (i.e., a list of lists). Specifically, it is a length-n list of lists, each containing m indices, where n is the number of folds and m is the number of individuals in the test set for a given fold. (As an aside, this seems less intuitive to me than just an n x m array.)
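For reference, a nested test_idxs of this shape can be built with the standard library alone. This is a minimal sketch, not necessarily how the tutorial notebook builds it, and make_test_idxs is a hypothetical helper. One practical reason to prefer a list of lists over an n x m array is that folds need not be the same size when the number of subjects is not divisible by n:

```python
import random


def make_test_idxs(num_samples, n_folds, seed=0):
    """Split subject indices 0..num_samples-1 into n_folds disjoint test sets."""
    idx = list(range(num_samples))
    random.Random(seed).shuffle(idx)
    # Round-robin assignment: fold k gets every n_folds-th shuffled index.
    return [sorted(idx[k::n_folds]) for k in range(n_folds)]


test_idxs = make_test_idxs(10, 3)
print([len(fold) for fold in test_idxs])  # [4, 3, 3]: ragged, so not an n x m array
```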

Then, the user is prompted to pass the select_fold argument if they wish to run in a parallel context. The default is an empty list, which to me is unintuitive. Why another list? I would have expected this argument to just be an integer, 0 through n-1, where n is the number of folds. I'm not sure what the argument is actually supposed to be. It appears from the default (and your prior comment) that it's supposed to be an array, but if the user has already done the work of compiling the list of lists in the first place, why pass another array to select the fold? I would find it easier to just pass an integer indicating which fold to use, and the solution I proposed (janky as it may be) allows that.
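A sketch of the interface suggested here, where select_fold is a fold number (or, as a small extension, a list of fold numbers) and the code looks up the test indices itself. resolve_folds is a hypothetical helper, and using None as the "run every fold" default is an assumption, not pySuStaIn's API:

```python
def resolve_folds(n_folds, select_fold=None):
    """Normalize select_fold into a list of valid fold numbers (plain ints only)."""
    if select_fold is None or select_fold == []:
        return list(range(n_folds))  # default: run every fold
    if isinstance(select_fold, int):
        select_fold = [select_fold]  # a single fold, e.g. one parallel job
    folds = list(select_fold)
    for fold in folds:
        if not 0 <= fold < n_folds:
            raise ValueError(f"fold {fold} is outside 0..{n_folds - 1}")
    return folds


test_idxs = [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
for fold in resolve_folds(len(test_idxs), select_fold=0):
    indx_test = test_idxs[fold]
    indx_train = [x for x in range(10) if x not in indx_test]
    print(fold, indx_test, indx_train)  # 0 [0, 3, 6, 9] [1, 2, 4, 5, 7, 8]
```

Because fold 0 is compared against None rather than tested for truthiness, it is handled like any other fold.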

I may just be misunderstanding something, and this comment might be obviated by documentation explaining the expected input. But I couldn't find any documentation for this function, nor was there an example in the tutorial notebook.

Anyway, I've changed it locally and everything is fine, so consider this just a suggestion and no worries if you disagree! Just some food for thought.

<3 --Jake