PNNL-CompBio / coderdata

Automation scripts and benchmark dataset package for cancer drug prediction deep learning models.

Train test validate #244

Closed ymahlich closed 1 week ago

ymahlich commented 1 week ago

PR for the training / testing and validation set generation scripts.

PR adds coderdata.split.splitter which contains:

train_test_validate()

This is the main (and "public") function of the submodule. It generates train/test/validation splits and returns three individual CoderData objects, one each for train, test and validate. The arguments that modify how the individual splits are generated, as illustrated by the example call and its description below, are split_type, ratio, stratify_by, random_state, num_classes and quantiles.

_create_classes()

Internal "private" helper function that creates the classes needed for the stratification. Arguments (besides the dataset) are:

_filter()

Internal "private" helper function that aids in creating filtered subsets of the reference CoderData object which only contain data points that pertain to the individual train / test & validate sets.
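To illustrate the idea behind the filtering step, here is a minimal sketch in plain Python. It is not the code from this PR: the function name filter_records, the sample_id key and the list-of-dicts layout are assumptions made for illustration only, and the real _filter() operates on a CoderData object.

```python
def filter_records(records, keep_ids, key="sample_id"):
    """Return only the records whose identifier belongs to one split.

    `records` is a list of dicts and `key` names the identifier field;
    both are hypothetical stand-ins for the tables in a CoderData object.
    """
    keep = set(keep_ids)  # set lookup keeps the filter linear in len(records)
    return [r for r in records if r[key] in keep]


# Toy experiment table: two drugs measured on three samples.
experiments = [
    {"sample_id": 1, "drug": "A", "fit_auc": 0.42},
    {"sample_id": 2, "drug": "A", "fit_auc": 0.77},
    {"sample_id": 1, "drug": "B", "fit_auc": 0.55},
    {"sample_id": 3, "drug": "B", "fit_auc": 0.91},
]

# Keep only the rows that pertain to the samples assigned to "train".
train_rows = filter_records(experiments, keep_ids=[1, 3])
```

Applying the same filter once per split, with the respective identifier set, yields three subsets that together cover the reference data without overlap.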

Example call:

import coderdata as cd
data = cd.DatasetLoader('beataml')
train, test, validate = cd.train_test_validate(
    data,
    split_type='cancer-blind',
    ratio=[8,1,1],
    stratify_by='fit_auc',
    random_state=42,
    num_classes=5,
    )

The call detailed above generates training, testing and validation CoderData objects based on the BeatAML dataset. The splits are generated such that the individual sets are cancer-blind, i.e. the cell lines on which drugs are tested in train are not present in either test or validate, and vice versa. The ratio of the split sizes is 8:1:1 for train/test/validate. The split is stratified using fit_auc as the reference value. For the stratification, 5 classes are generated internally (num_classes=5) using quantiles. Quantiles do not need to be requested in the function call since they are the default behavior; if evenly spaced classes are desired, set quantiles=False. Finally, the seed for the randomization is set to 42 (random_state=42) to generate a reproducible split.
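The two ideas described above, binning a continuous value into classes and splitting by sample identity rather than by row, can be sketched with the standard library alone. This is a simplified illustration, not the implementation in coderdata.split.splitter: make_classes and blind_split are hypothetical names, and the blind split shown here does not stratify.

```python
import bisect
import random
import statistics


def make_classes(values, num_classes=5, quantiles=True):
    """Assign each value a class label in 0..num_classes-1."""
    if quantiles:
        # Cut points chosen so that each class holds roughly the same count.
        edges = statistics.quantiles(values, n=num_classes, method="inclusive")
    else:
        # Cut points evenly spaced between the minimum and the maximum.
        lo, hi = min(values), max(values)
        width = (hi - lo) / num_classes
        edges = [lo + width * i for i in range(1, num_classes)]
    return [bisect.bisect_left(edges, v) for v in values]


def blind_split(sample_ids, ratio=(8, 1, 1), random_state=None):
    """Split unique sample IDs so that no ID appears in more than one set."""
    unique = sorted(set(sample_ids))
    random.Random(random_state).shuffle(unique)  # seeded, hence reproducible
    total = sum(ratio)
    n_train = round(len(unique) * ratio[0] / total)
    n_test = round(len(unique) * ratio[1] / total)
    train = set(unique[:n_train])
    test = set(unique[n_train:n_train + n_test])
    validate = set(unique[n_train + n_test:])
    return train, test, validate


# A skewed toy response: one extreme value dominates the range.
auc_like = [0, 1, 2, 3, 100]
by_quantile = make_classes(auc_like, num_classes=5, quantiles=True)
evenly_spaced = make_classes(auc_like, num_classes=5, quantiles=False)

train_ids, test_ids, validate_ids = blind_split(range(100), random_state=42)
```

On skewed data, evenly spaced classes can leave most bins empty, whereas quantile-based classes stay balanced, which is a plausible reason for quantiles being the default.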

What this PR DOESN'T do:

Implement a class method akin to dataset.train_test_validate() that can be called directly on a loaded CoderData object.

sgosline commented 1 week ago

I'm good to merge for now!