BRANCHlab / metasnf

Scalable subtyping with similarity network fusion
https://branchlab.github.io/metasnf/

Full data list for label propagation #6

Closed apdlbalb closed 7 months ago

apdlbalb commented 8 months ago

Would you consider baking in a call to sort() to handle the train_subjects and test_subjects arguments in the generate_data_list() function? I used sample() on my patient IDs to create the train and test splits so my vectors were unordered!

Thank you!
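
Until sorting is handled inside the package, one workaround is to sort the sampled IDs before passing them on. Below is a minimal base R sketch. The patient IDs and the 8/2 split are made up for illustration; only the train_subjects and test_subjects names come from the question above.

```r
# Hypothetical patient IDs; in practice these come from your own data
set.seed(42)
patient_ids <- sprintf("subject_%02d", 1:10)

# sample() returns the IDs in random order...
train_subjects <- sample(patient_ids, size = 8)
test_subjects <- setdiff(patient_ids, train_subjects)

# ...so sort both vectors before using them as function arguments
train_subjects <- sort(train_subjects)
test_subjects <- sort(test_subjects)
```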

pvelayudhan commented 8 months ago

This brings up a very good point - I think sorting in this package has always felt a little clunky and could use a bit of an overhaul. Right now generate_data_list auto-sorts across all patients by their UID, which was a bit of a hacky way to handle input dataframes being loaded in with non-identical patient orders. As a result, getting the order correct for label propagation is unnecessarily complicated.

On top of that, the label propagation function has all the information it needs to know who the train and test subjects are without the user necessarily getting the order of things correct, so it should be less finicky to work with.

I'm thinking of moving to:

  1. Adding easier utility functions for resorting patient order in existing data_lists (e.g., resorting to match a vector of patient IDs)
  2. Making functions where patient order is important (like label propagation) more helpful when the order isn't correct, e.g. by autosorting like you mention and raising a warning about patient order
  3. Getting rid of any default automatic patient sorting in generate_data_list and elsewhere so that users have more transparent control over the order of their subjects.
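
As a rough idea of what points 1 and 2 could look like, here is a base R sketch that reorders a plain data frame to match a vector of IDs and warns when it has to. The align_subjects helper and the toy data are hypothetical and not part of metasnf; the sketch only assumes a subjectkey column like the one used in the package.

```r
# Hypothetical helper (not part of metasnf): put `df` in the order given by
# `expected_ids`, warning if it was not already in that order.
align_subjects <- function(df, expected_ids, key = "subjectkey") {
    if (!setequal(df[[key]], expected_ids)) {
        stop("The data and expected_ids do not contain the same subjects.")
    }
    if (!identical(as.character(df[[key]]), as.character(expected_ids))) {
        warning("Subjects were not in the expected order; reordering.")
        # match() finds, for each expected ID, its row position in df
        df <- df[match(expected_ids, df[[key]]), , drop = FALSE]
    }
    df
}

# Toy data in a shuffled order
df <- data.frame(subjectkey = c("s3", "s1", "s2"), age = c(30, 10, 20))
aligned <- align_subjects(df, c("s1", "s2", "s3"))  # raises the warning
aligned$subjectkey
```

The warning keeps the reordering visible to the user instead of silently changing the data, which is the trade-off point 2 is after.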

I will try and get something sorted out for this shortly!

pvelayudhan commented 8 months ago

Edit: the code below doesn't work (lol) - comment below has the correct version

@apdlbalb The updated label propagation function (lp_solutions_matrix) is below - it effectively does the sorting you mentioned and no longer cares about subject order at all! As long as you provide a train_solutions_matrix built by running batch_snf on your training subjects and a full_data_list that has the data of both the training and testing set subjects, the function should run just fine. It is not yet pushed to main - I want to split the label propagation part of the "less simple example" vignette out into its own thing that uses the updated function, but will do that some time later this week. It is backwards compatible with the previous lp_row function.

#' Label propagate cluster solutions to unclustered subjects
#'
#' Given a solutions_matrix derived from training subjects and a full_data_list
#' containing both training and test subjects, re-run SNF to generate a total
#' affinity matrix of both train and test subjects and use the label
#' propagation algorithm to assign predicted clusters to test subjects.
#'
#' @param train_solutions_matrix A solutions_matrix derived from the training
#' subjects. The propagation algorithm is slow and should be used for
#' validating a top or top few meaningful chosen clustering solutions. It is
#' advisable to use only a small subset of rows from the original
#' solutions_matrix for label propagation.
#' @param full_data_list A data_list containing both the training and test
#' subjects.
#' @param clust_algs_list If a custom clustering algorithm list was used during
#' the original batch_snf call, include that clust_algs_list here as well.
#' @param distance_metrics_list Like above - the distance_metrics_list (if any)
#' that was used for the original batch_snf call.
#' @param weights_matrix Like above.
#'
#' @return labeled_df a dataframe containing a column for subjectkeys,
#' a column for whether the subject was in the train (original) or test (held
#' out) set, and one column per row of the solutions matrix indicating the
#' original and propagated clusters.
#'
#' @export
lp_solutions_matrix <- function(train_solutions_matrix,
                                full_data_list,
                                clust_algs_list = NULL,
                                distance_metrics_list = NULL,
                                weights_matrix = NULL) {
    ###########################################################################
    # 1. Reorder data_list subjects
    ###########################################################################
    train_subjects <- colnames(subs(train_solutions_matrix))[-1]
    all_subjects <- full_data_list[[1]][[1]]$"subjectkey"
    # Quick check to make sure the train subjects are all in the full list
    if (!all(train_subjects %in% all_subjects)) {
        stop(
            "Some of the subjects with known clusters in the ",
            "train_solutions_matrix are not present in the full_data_list."
        )
    }
    test_subjects <- all_subjects[!all_subjects %in% train_subjects]
    lp_ordered_subjects <- c(train_subjects, test_subjects)
    full_data_list <- reorder_dl_subs(full_data_list, lp_ordered_subjects)
    ###########################################################################
    # 2. Prepare vectors containing the names of the train and test subjects
    ###########################################################################
    n_train <- length(train_subjects)
    n_test <- length(test_subjects)
    group_vec <- c(rep("train", n_train), rep("test", n_test))
    ###########################################################################
    # 3. SNF of the full data list
    ###########################################################################
    ###########################################################################
    ## 3-1. Creation of distance_metrics_list, if it does not already exist
    ###########################################################################
    if (is.null(distance_metrics_list)) {
        distance_metrics_list <- generate_distance_metrics_list()
    }
    ###########################################################################
    ## 3-2. Create (or check) weights_matrix
    ###########################################################################
    if (is.null(weights_matrix)) {
        weights_matrix <- generate_weights_matrix(
            full_data_list,
            nrow = nrow(train_solutions_matrix)
        )
    } else {
        if (nrow(weights_matrix) != nrow(train_solutions_matrix)) {
            stop(
                paste0(
                    "Weights_matrix and train_solutions_matrix",
                    " should have the same number of rows."
                )
            )
        }
    }
    ###########################################################################
    ## 3-3. SNF one row at a time
    ###########################################################################
    for (i in seq_len(nrow(train_solutions_matrix))) {
        print(
            paste0(
                "Processing row ", i, " of ",
                nrow(train_solutions_matrix), "..."
            )
        )
        current_row <- train_solutions_matrix[i, ]
        sig <- paste0(current_row$"row_id")
        reduced_dl <- drop_inputs(current_row, full_data_list)
        scheme <- current_row$"snf_scheme"
        k <- current_row$"k"
        alpha <- current_row$"alpha"
        t <- current_row$"t"
        cont_dist <- current_row$"cont_dist"
        disc_dist <- current_row$"disc_dist"
        ord_dist <- current_row$"ord_dist"
        cat_dist <- current_row$"cat_dist"
        mix_dist <- current_row$"mix_dist"
        cont_dist_fn <- distance_metrics_list$"continuous_distance"[[cont_dist]]
        disc_dist_fn <- distance_metrics_list$"discrete_distance"[[disc_dist]]
        ord_dist_fn <- distance_metrics_list$"ordinal_distance"[[ord_dist]]
        cat_dist_fn <- distance_metrics_list$"categorical_distance"[[cat_dist]]
        mix_dist_fn <- distance_metrics_list$"mixed_distance"[[mix_dist]]
        weights_row <- weights_matrix[i, , drop = FALSE]
        #######################################################################
        # The actual SNF
        #######################################################################
        full_fused_network <- snf_step(
            reduced_dl,
            scheme = scheme,
            k = k,
            alpha = alpha,
            t = t,
            cont_dist_fn = cont_dist_fn,
            disc_dist_fn = disc_dist_fn,
            ord_dist_fn = ord_dist_fn,
            cat_dist_fn = cat_dist_fn,
            mix_dist_fn = mix_dist_fn,
            weights_row = weights_row
        )
        full_fused_network <- full_fused_network[
            lp_ordered_subjects,
            lp_ordered_subjects
        ]
        clusters <- get_clusters(current_row)
        #######################################################################
        # Label propagation
        #######################################################################
        propagated_labels <- label_prop(full_fused_network, clusters)
        if (i == 1) {
            labeled_df <- data.frame(
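                # NOTE: all_subjects is in the original data_list order, while
                # group_vec and propagated_labels are in train-then-test order
                # (the else branch below has the same issue)
                subjectkey = all_subjects,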
                group = group_vec,
                cluster = propagated_labels
            )
            names <- colnames(labeled_df)
            names[which(names == "cluster")] <- sig
            colnames(labeled_df) <- names
        } else {
            current_df <- data.frame(
                subjectkey = all_subjects,
                group = group_vec,
                cluster = propagated_labels
            )
            names <- colnames(current_df)
            names[which(names == "cluster")] <- sig
            colnames(current_df) <- names
            labeled_df <- dplyr::inner_join(
                labeled_df,
                current_df,
                by = c("subjectkey", "group")
            )
        }
    }
    return(labeled_df)
}

pvelayudhan commented 8 months ago

Oops, it also needs this to run:

#' Reorder the subjects in a data_list
#'
#' @param data_list data_list to reorder
#' @param ordered_subjects A vector of the subjectkey values in the data_list
#' in the desired order of the sorted data_list.
#'
#' @export
reorder_dl_subs <- function(data_list, ordered_subjects) {
    data_list <- data_list |>
        lapply(
            function(x) {
                index <- match(x$"data"$"subjectkey", ordered_subjects)
                x$"data" <- x$"data"[order(index), ]
                return(x)
            }
        )
    return(data_list)
}
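
The core of reorder_dl_subs is the order(match(...)) idiom. Here is a small base R illustration on a single toy data frame (the values are invented), including what happens to a subject that is missing from ordered_subjects:

```r
df <- data.frame(
    subjectkey = c("s2", "s4", "s1", "s3"),
    value = c(20, 40, 10, 30)
)
ordered_subjects <- c("s3", "s1", "s2")  # "s4" deliberately left out

# match() gives each row's position in the desired order (NA if absent)
index <- match(df$subjectkey, ordered_subjects)
index  # 3 NA 2 1

# order() sorts rows by that position; NA positions go last by default
reordered <- df[order(index), ]
reordered$subjectkey  # "s3" "s1" "s2" "s4"
```

So with this idiom a subject missing from ordered_subjects is kept at the end rather than dropped, which is worth checking for if the vector of IDs might be incomplete.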

pvelayudhan commented 8 months ago

Well I'm glad I didn't end up pushing anything - that last label propagation function didn't sort the results properly at all.
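
Judging by the difference between the two versions, the problem appears to be that the first one labelled the output rows with all_subjects (original data_list order) while the propagated labels follow the train-then-test order. Here is a toy base R illustration of that mismatch, with invented IDs and cluster labels:

```r
all_subjects <- c("s1", "s2", "s3", "s4")  # original data_list order
train_subjects <- c("s4", "s2")            # order from the solutions matrix
test_subjects <- all_subjects[!all_subjects %in% train_subjects]
lp_ordered_subjects <- c(train_subjects, test_subjects)

# Pretend labels, returned in lp_ordered_subjects order (s4, s2, s1, s3)
propagated_labels <- c(1, 1, 2, 2)

# Keys in the original order: labels land on the wrong subjects
wrong <- data.frame(subjectkey = all_subjects, cluster = propagated_labels)
# Keys in the same order as the labels: each subject keeps its own label
right <- data.frame(
    subjectkey = lp_ordered_subjects,
    cluster = propagated_labels
)

wrong$cluster[wrong$subjectkey == "s4"]  # 2
right$cluster[right$subjectkey == "s4"]  # 1
```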

This should be the correct version:

#' Label propagate cluster solutions to unclustered subjects
#'
#' Given a solutions_matrix derived from training subjects and a full_data_list
#' containing both training and test subjects, re-run SNF to generate a total
#' affinity matrix of both train and test subjects and use the label
#' propagation algorithm to assign predicted clusters to test subjects.
#'
#' @param train_solutions_matrix A solutions_matrix derived from the training
#' subjects. The propagation algorithm is slow and should be used for
#' validating a top or top few meaningful chosen clustering solutions. It is
#' advisable to use only a small subset of rows from the original
#' solutions_matrix for label propagation.
#' @param full_data_list A data_list containing both the training and test
#' subjects.
#' @param clust_algs_list If a custom clustering algorithm list was used during
#' the original batch_snf call, include that clust_algs_list here as well.
#' @param distance_metrics_list Like above - the distance_metrics_list (if any)
#' that was used for the original batch_snf call.
#' @param weights_matrix Like above.
#'
#' @return labeled_df a dataframe containing a column for subjectkeys,
#' a column for whether the subject was in the train (original) or test (held
#' out) set, and one column per row of the solutions matrix indicating the
#' original and propagated clusters.
#'
#' @export
lp_solutions_matrix <- function(train_solutions_matrix,
                                full_data_list,
                                clust_algs_list = NULL,
                                distance_metrics_list = NULL,
                                weights_matrix = NULL) {
    ###########################################################################
    # 1. Reorder data_list subjects
    ###########################################################################
    train_subjects <- colnames(subs(train_solutions_matrix))[-1]
    all_subjects <- full_data_list[[1]][[1]]$"subjectkey"
    # Quick check to make sure the train subjects are all in the full list
    if (!all(train_subjects %in% all_subjects)) {
        stop(
            "Some of the subjects with known clusters in the ",
            "train_solutions_matrix are not present in the full_data_list."
        )
    }
    test_subjects <- all_subjects[!all_subjects %in% train_subjects]
    lp_ordered_subjects <- c(train_subjects, test_subjects)
    full_data_list <- reorder_dl_subs(full_data_list, lp_ordered_subjects)
    ###########################################################################
    # 2. Prepare vectors containing the names of the train and test subjects
    ###########################################################################
    n_train <- length(train_subjects)
    n_test <- length(test_subjects)
    group_vec <- c(rep("train", n_train), rep("test", n_test))
    ###########################################################################
    # 3. SNF of the full data list
    ###########################################################################
    ###########################################################################
    ## 3-1. Creation of distance_metrics_list, if it does not already exist
    ###########################################################################
    if (is.null(distance_metrics_list)) {
        distance_metrics_list <- generate_distance_metrics_list()
    }
    ###########################################################################
    ## 3-2. Create (or check) weights_matrix
    ###########################################################################
    if (is.null(weights_matrix)) {
        weights_matrix <- generate_weights_matrix(
            full_data_list,
            nrow = nrow(train_solutions_matrix)
        )
    } else {
        if (nrow(weights_matrix) != nrow(train_solutions_matrix)) {
            stop(
                paste0(
                    "Weights_matrix and train_solutions_matrix",
                    " should have the same number of rows."
                )
            )
        }
    }
    ###########################################################################
    ## 3-3. SNF one row at a time
    ###########################################################################
    for (i in seq_len(nrow(train_solutions_matrix))) {
        print(
            paste0(
                "Processing row ", i, " of ",
                nrow(train_solutions_matrix), "..."
            )
        )
        current_row <- train_solutions_matrix[i, ]
        sig <- paste0(current_row$"row_id")
        reduced_dl <- drop_inputs(current_row, full_data_list)
        scheme <- current_row$"snf_scheme"
        k <- current_row$"k"
        alpha <- current_row$"alpha"
        t <- current_row$"t"
        cont_dist <- current_row$"cont_dist"
        disc_dist <- current_row$"disc_dist"
        ord_dist <- current_row$"ord_dist"
        cat_dist <- current_row$"cat_dist"
        mix_dist <- current_row$"mix_dist"
        cont_dist_fn <- distance_metrics_list$"continuous_distance"[[cont_dist]]
        disc_dist_fn <- distance_metrics_list$"discrete_distance"[[disc_dist]]
        ord_dist_fn <- distance_metrics_list$"ordinal_distance"[[ord_dist]]
        cat_dist_fn <- distance_metrics_list$"categorical_distance"[[cat_dist]]
        mix_dist_fn <- distance_metrics_list$"mixed_distance"[[mix_dist]]
        weights_row <- weights_matrix[i, , drop = FALSE]
        #######################################################################
        # The actual SNF
        #######################################################################
        full_fused_network <- snf_step(
            reduced_dl,
            scheme = scheme,
            k = k,
            alpha = alpha,
            t = t,
            cont_dist_fn = cont_dist_fn,
            disc_dist_fn = disc_dist_fn,
            ord_dist_fn = ord_dist_fn,
            cat_dist_fn = cat_dist_fn,
            mix_dist_fn = mix_dist_fn,
            weights_row = weights_row
        )
        full_fused_network <- full_fused_network[
            lp_ordered_subjects,
            lp_ordered_subjects
        ]
        clusters <- get_clusters(current_row)
        #######################################################################
        # Label propagation
        #######################################################################
        propagated_labels <- label_prop(full_fused_network, clusters)
        # Build this row's results; the cluster column is named after row_id
        current_df <- data.frame(
            subjectkey = c(train_subjects, test_subjects),
            group = group_vec,
            cluster = propagated_labels
        )
        colnames(current_df)[colnames(current_df) == "cluster"] <- sig
        if (i == 1) {
            labeled_df <- current_df
        } else {
            labeled_df <- dplyr::inner_join(
                labeled_df,
                current_df,
                by = c("subjectkey", "group")
            )
        }
    }
    return(labeled_df)
}

pvelayudhan commented 7 months ago

The lp_solutions_matrix function should now handle label propagation properly without any fuss!
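
For readers wondering why the train-then-test ordering matters to the propagation step, here is a generic iterative label propagation sketch in base R. It is an assumption for illustration only and not the metasnf label_prop implementation: propagate_labels, its arguments, and the toy affinity matrix are all hypothetical. Like the function above, it expects the labelled (train) subjects to occupy the first rows of the affinity matrix.

```r
# Generic label propagation sketch (illustrative; not metasnf's label_prop())
propagate_labels <- function(affinity, train_labels, n_iter = 100) {
    n <- nrow(affinity)
    n_train <- length(train_labels)
    classes <- sort(unique(train_labels))
    # Row-normalise affinities into transition probabilities
    p <- affinity / rowSums(affinity)
    # One-hot label matrix; unlabelled (test) rows start at zero
    y0 <- matrix(0, nrow = n, ncol = length(classes))
    y0[cbind(seq_len(n_train), match(train_labels, classes))] <- 1
    y <- y0
    for (iter in seq_len(n_iter)) {
        y <- p %*% y
        # Clamp the known training labels after every step
        y[seq_len(n_train), ] <- y0[seq_len(n_train), ]
    }
    # Assign each subject the class with the largest score
    classes[apply(y, 1, which.max)]
}

# Toy network: subjects 1-2 are labelled, 3 resembles 1 and 4 resembles 2
affinity <- matrix(
    c(1.0, 0.1, 0.8, 0.1,
      0.1, 1.0, 0.1, 0.8,
      0.8, 0.1, 1.0, 0.1,
      0.1, 0.8, 0.1, 1.0),
    nrow = 4, byrow = TRUE
)
predicted <- propagate_labels(affinity, train_labels = c(1, 2))
predicted  # 1 2 1 2
```

Because the known labels are matched to rows purely by position, any mismatch between the row order of the network and the order of the training labels would propagate the wrong clusters.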