limebit / medmodels

MedModels is a high-speed RWE framework to apply the latest methods from scientific research to medical data.
https://www.medmodels.de
BSD 3-Clause "New" or "Revised" License

Adapt the matching function so it does not use polars dataframes conversion #117

Open MarIniOnz opened 4 months ago

MarIniOnz commented 4 months ago

We should adapt this function from the matching file and make it a feature of MedRecord instead.

What is needed?

Special considerations:

```python
from typing import Set, Tuple

import polars as pl

# MedRecord, NodeIndex and MedRecordAttributeInputList are medmodels types;
# their imports are omitted here.


def _preprocess_data(
    self,
    *,
    medrecord: MedRecord,
    control_group: Set[NodeIndex],
    treated_group: Set[NodeIndex],
    essential_covariates: MedRecordAttributeInputList,
    one_hot_covariates: MedRecordAttributeInputList,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """
    Prepares the data for the matching algorithms.

    Args:
        medrecord (MedRecord): MedRecord object containing the data.
        control_group (Set[NodeIndex]): Set of control subjects.
        treated_group (Set[NodeIndex]): Set of treated subjects.
        essential_covariates (MedRecordAttributeInputList): Covariates
            that are essential for matching.
        one_hot_covariates (MedRecordAttributeInputList): Covariates that
            are one-hot encoded for matching.

    Returns:
        Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their
            preprocessed covariates.
    """
    # Work with string column names throughout
    essential_covariates = [str(covariate) for covariate in essential_covariates]
    one_hot_covariates = [str(covariate) for covariate in one_hot_covariates]

    if "id" not in essential_covariates:
        essential_covariates.append("id")

    # Build one dataframe row per node, keeping the node index as "id"
    data = pl.DataFrame(
        data=[
            {"id": k, **v}
            for k, v in medrecord.node[list(control_group | treated_group)].items()
        ]
    )
    original_columns = data.columns

    # One-hot encode the categorical variables
    data = data.to_dummies(columns=one_hot_covariates, drop_first=True)
    new_columns = [col for col in data.columns if col not in original_columns]

    # Replace the one-hot encoded covariates with the columns created from them
    essential_covariates = [
        col for col in essential_covariates if col not in one_hot_covariates
    ] + new_columns
    data = data.select(essential_covariates)

    # Select the sets of treated and control subjects
    data_treated = data.filter(pl.col("id").is_in(list(treated_group)))
    data_control = data.filter(pl.col("id").is_in(list(control_group)))

    return data_treated, data_control
```

JabobKrauskopf commented 4 months ago

@MarIniOnz Could you add an example for a medrecord with one or two nodes of what the output should look like exactly? :)

MarIniOnz commented 4 months ago

MedRecord with nodes:

Node1. Attributes: gender: "female", age: 21
Node2. Attributes: gender: "male", age: 45

`medrecord.nodes_attributes([node1, node2])` should return either a polars dataframe:

| gender | age |
| ------ | --- |
| female | 21  |
| male   | 45  |

or `np.array([["female", "male"], [21, 45]])`

or, with one-hot encoding, `medrecord.nodes_attributes([node1, node2], hot_encoding="gender")` should return `np.array([[0, 1], [21, 45]])`
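
To make the expected output concrete, here is a minimal pure-Python sketch that avoids the polars conversion. It is only an illustration under assumptions: the `nodes` dictionary stands in for a MedRecord, and this `nodes_attributes` function with its `hot_encoding` parameter is a hypothetical free function, not an existing MedRecord method. Dropping the first category (in sorted order) only loosely mirrors `drop_first=True` and may not match polars' choice of which category to drop.

```python
from typing import Any, Dict, Hashable, List, Optional

# Hypothetical stand-in for a MedRecord: node index -> attribute dictionary.
Nodes = Dict[Hashable, Dict[str, Any]]


def nodes_attributes(
    nodes: Nodes,
    indices: List[Hashable],
    hot_encoding: Optional[str] = None,
) -> Dict[str, List[Any]]:
    """Return the attributes of the given nodes column-wise.

    Assumes every selected node carries the same attributes. If
    `hot_encoding` names an attribute, that attribute is replaced by 0/1
    indicator columns, one per category except the first (sorted) one.
    """
    columns: Dict[str, List[Any]] = {}
    if not indices:
        return columns

    for name in nodes[indices[0]]:
        values = [nodes[index][name] for index in indices]
        if name == hot_encoding:
            categories = sorted(set(values))
            # Skip the first category so the indicators are not redundant.
            for category in categories[1:]:
                columns[f"{name}_{category}"] = [
                    int(value == category) for value in values
                ]
        else:
            columns[name] = values
    return columns


nodes = {
    "node1": {"gender": "female", "age": 21},
    "node2": {"gender": "male", "age": 45},
}

plain = nodes_attributes(nodes, ["node1", "node2"])
# plain == {"gender": ["female", "male"], "age": [21, 45]}

encoded = nodes_attributes(nodes, ["node1", "node2"], hot_encoding="gender")
# encoded == {"gender_male": [0, 1], "age": [21, 45]}
```

The column-wise dictionary could be wrapped in a numpy array or a polars dataframe by the caller, so the conversion stays optional.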

JabobKrauskopf commented 3 months ago

Can be fixed after #146