SupervisedStylometry / SuperStyl

Supervised Stylometry
GNU General Public License v3.0

Missing a variable declaration for training cross_validation other than group-k-fold #35

Closed Louis-Fiacre closed 1 year ago

Louis-Fiacre commented 1 year ago

In superstyl/svm.py, when training with a cross_validate choice other than group-k-fold, the works variable must be declared upstream, outside the cross_validate == 'group-k-fold' conditional. Otherwise, the groups=works argument passed to skmodel.cross_val_predict() is undefined.

    if cross_validate is not None:

        works = None  # must be defined here, before the conditionals below

        if cross_validate == 'leave-one-out':
            myCV = skmodel.LeaveOneOut()

        if cross_validate == 'k-fold':
            myCV = skmodel.KFold(n_splits=k)

        if cross_validate == 'group-k-fold':
            # Get the groups as the different source texts
            works = [t.split('_')[0] for t in train.index.values]
            myCV = skmodel.GroupKFold(n_splits=len(set(works)))

        print(".......... " + cross_validate + " cross validation will be performed ........")
        print(".......... using " + str(myCV.get_n_splits(train)) + " samples or groups ........")

        # Will need to
        # 1. train a model
        # 2. get predictions
        # 3. compute scores: precision, recall, F1 for all categories

        preds = skmodel.cross_val_predict(pipe, train, classes, cv=myCV, verbose=1, n_jobs=-1, groups=works)
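
For context, here is a minimal, self-contained sketch (not the SuperStyl pipeline itself) of why declaring works = None upstream is sufficient: sklearn's cross_val_predict accepts groups=None, and splitters such as KFold or LeaveOneOut simply ignore that argument, so only GroupKFold needs an actual group list. The skmodel alias for sklearn.model_selection, the classifier, and the toy data below are assumptions for illustration only.

    import numpy as np
    from sklearn import model_selection as skmodel  # assumed alias, as in the snippet above
    from sklearn.linear_model import LogisticRegression

    X = np.random.rand(20, 5)            # stand-in for the training feature matrix
    y = np.repeat(["A", "B"], 10)        # stand-in for the author classes

    works = None                         # declared before the conditionals, as in the proposed fix
    cross_validate = "k-fold"

    if cross_validate == "k-fold":
        myCV = skmodel.KFold(n_splits=5)
    if cross_validate == "group-k-fold":
        works = ["w1", "w2"] * 10        # hypothetical work labels, one per sample
        myCV = skmodel.GroupKFold(n_splits=len(set(works)))

    # With works = None, KFold.split() ignores the groups argument, so the call
    # raises no NameError and the non-grouped strategies behave exactly as before.
    preds = skmodel.cross_val_predict(LogisticRegression(), X, y, cv=myCV, groups=works)
    print(preds[:5])

The same reasoning applies to LeaveOneOut, so initializing works once before the conditionals covers all three cross-validation strategies.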
Jean-Baptiste-Camps commented 1 year ago

Thanks a lot! I've fixed that using the exact same solution in the latest merge, #36.