quanteda / quanteda.classifiers

quanteda textmodel extensions for classifying documents

crossval summarizes over folds before classes #27

Open pchest opened 4 years ago

pchest commented 4 years ago

I noticed a minor issue with the order in which crossval and summarize_results perform their computations. summarize_results aggregates over the folds dimension first, and only after that output is produced does crossval aggregate over classes (if by_class = FALSE). This may be problematic: in multi-class problems, where the performance of the model across all classes is the outcome of interest, precision, recall, and F1 are calculated within each class and then aggregated across classes to produce a final score for the model being evaluated. Aggregating across folds first can give a different result. I propose that we simply move the summ <- apply(summ, 2, mean) step inside the summarize_results function, to a line preceding apply(x_array, c(1, 2), mean).

crossval <- function(x, k = 5, by_class = FALSE) {
    ...
    summ <- summarize_results(results)
    if (!by_class)
        summ <- apply(summ, 2, mean)
    cat("Cross-validation:\n\nMean results for k =", k, "folds:\n\n")
    print(summ)
    invisible(summ)
}

summarize_results <- function(x) {
    # remove the "obs" element (safe even when "obs" is absent)
    x <- lapply(x, function(y) y[names(y) != "obs"])

    # make into a 3D array
    x_df <- lapply(x, data.frame)
    x_array <- array(unlist(x), dim = c(dim(x_df[[1]]), length(x_df)),
                     dimnames = c(dimnames(x_df[[1]]), list(names(x))))

    apply(x_array, c(1, 2), mean)
}
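
A minimal sketch of what the proposed reordering could look like, in case it helps the discussion. It assumes that each fold's result is a list of equal-length, per-class metric vectors plus an "obs" element, and that by_class is passed down to summarize_results; that extra argument and the toy results list are illustrative assumptions, not the package's current interface.

```r
# Sketch only: `by_class` as an argument here is an assumption.
summarize_results <- function(x, by_class = FALSE) {
    # drop the "obs" element from each fold
    x <- lapply(x, function(y) y[names(y) != "obs"])

    # stack the folds into a 3D array: class x metric x fold
    x_df <- lapply(x, data.frame)
    x_array <- array(unlist(x), dim = c(dim(x_df[[1]]), length(x_df)),
                     dimnames = c(dimnames(x_df[[1]]), list(names(x))))

    if (!by_class) {
        # average over classes within each fold first: metric x fold
        per_fold <- apply(x_array, c(2, 3), mean)
        # then average over folds: one value per metric
        return(apply(per_fold, 1, mean))
    }

    # by_class = TRUE: average over folds only, keeping class x metric
    apply(x_array, c(1, 2), mean)
}

# Hypothetical two-fold, two-class results, for illustration only
results <- list(
    fold1 = list(precision = c(a = 0.8, b = 0.6),
                 recall    = c(a = 0.7, b = 0.5),
                 f1        = c(a = 0.75, b = 0.55),
                 obs       = 10),
    fold2 = list(precision = c(a = 0.6, b = 0.4),
                 recall    = c(a = 0.9, b = 0.3),
                 f1        = c(a = 0.72, b = 0.34),
                 obs       = 10)
)

overall   <- summarize_results(results)                   # one value per metric
per_class <- summarize_results(results, by_class = TRUE)  # class x metric matrix
```

With this arrangement crossval would call summarize_results(results, by_class = by_class) and no longer needs its own apply(summ, 2, mean) step. Note that when every fold has a complete set of per-class values, plain means taken in either order give the same numbers; the order starts to matter once some per-class values are missing in a fold (for example, an undefined precision for a class that was never predicted) and are dropped during averaging.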