OHDSI / PatientLevelPrediction

An R package for performing patient level prediction in an observational database in the OMOP Common Data Model.
https://ohdsi.github.io/PatientLevelPrediction

Code base refactor - Discussion #468

Open lhjohn opened 3 months ago

lhjohn commented 3 months ago

A place to discuss the refactor of PLP and to get an overview of the current code base and the options for a prospective one. Currently the project is organized by file and is function-based. Below is a "class" diagram of all files and functions in the R folder.

classDiagram
class AdditionalCovariates {
createCohortCovariateSettings
getCohortCovariateData
}
class AndromedaHelperFunctions {
batchRestrict
calculatePrevs
limitCovariatesToPopulation
}
class CalibrationSummary {
getCalibrationSummary
getCalibrationSummary_binary
getCalibrationSummary_survival
}
class CovariateSummary {
aggregateCovariateSummaries
covariateSummary
covariateSummarySubset
createCovariateSubsets
getCovariatesForGroup
}
class CyclopsModels {
createCyclopsModel
filterCovariateIds
fitCyclopsModel
getCV
getVariableImportance
modelTypeToCyclopsModelType
predictCyclops
predictCyclopsType
reparamTransferCoefs
}
class CyclopsSettings {
setCoxModel
setIterativeHardThresholding
setLassoLogisticRegression
}
class DatabaseMigration {
getDataMigrator
migrateDataModel
}
class DataSplitting {
checkInputsSplit
createDefaultSplitSetting
dataSummary
randomSplitter
splitData
subjectSplitter
timeSplitter
}
class DemographicSummary {
getDemographicSummary
getDemographicSummary_binary
getDemographicSummary_survival
}
class DiagnosePlp {
cos_sim
diagnoseMultiplePlp
diagnosePlp
getDiagnostic
getMaxEndDaysFromCovariates
getOutcomeSummary
probastDesign
probastOutcome
probastParticipants
probastPredictors
}
class EvaluatePlp {
evaluatePlp
modelBasedConcordance
}
class EvaluationSummary {
aucWithCi
aucWithoutCi
averagePrecision
brierScore
calculateEStatisticsBinary
calibrationInLarge
calibrationInLargeIntercept
calibrationLine
calibrationWeak
computeAuc
getEvaluationStatistics
getEvaluationStatistics_binary
getEvaluationStatistics_survival
ici
}
class ExternalValidatePlp {
createValidationDesign
createValidationSettings
externalValidateDbPlp
externalValidatePlp
validateExternal
validateModel
}
class ExtractData {
createDatabaseDetails
createRestrictPlpDataSettings
getPlpData
print.plpData
print.summary.plpData
summary.plpData
}
class FeatureEngineering {
calculateStratifiedMeans
createFeatureEngineeringSettings
createRandomForestFeatureSelection
createSplineSettings
createStratifiedImputationSettings
createUnivariateFeatureSelection
featureEngineer
imputeMissingMeans
randomForestFeatureSelection
splineCovariates
splineMap
stratifiedImputeCovariates
univariateFeatureSelection
}
class FeatureImportance {
permute
permutePerf
pfi
}
class Fit {
fitPlp
}
class Formatting {
checkRam
MapIds
toSparseM
}
class GradientBoostingMachine {
fitXgboost
predictXgboost
setGradientBoostingMachine
varImpXgboost
}
class HelperFunctions {
configurePython
createTempModelLoc
cut2
ensure_installed
getOs
is_installed
listAppend
nrow
nrow.default
nrow.tbl
removeInvalidString
setPythonEnvironment
}
class ImportFromCsv {
extractCohortDefinitionsCSV
extractDatabaseListCSV
extractDiagnosticFromCsv
extractObjectFromCsv
getModelDesignCsv
getModelDesignSettingTable
getPerformanceEvaluationCsv
getTableNamesPlp
insertCsvToDatabase
}
class KNN {
fitKNN
predictKnn
setKNN
}
class LearningCurve {
createLearningCurve
getTrainFractions
lcWrapper
learningCurveHelper
plotLearningCurve
}
class LightGBM {
fitLightGBM
predictLightGBM
setLightGBM
varImpLightGBM
}
class Logging {
checkFileExists
closeLog
createLog
createLogSettings
}
class ParamChecks {
checkBoolean
checkHigher
checkHigherEqual
checkInStringVector
checkIsClass
checkLower
checkLowerEqual
checkNotNull
}
class PatientLevelPrediction

class Plotting {
outcomeSurvivalPlot
plotDemographicSummary
plotF1Measure
plotGeneralizability
plotPlp
plotPrecisionRecall
plotPredictedPDF
plotPredictionDistribution
plotPreferencePDF
plotSmoothCalibration
plotSmoothCalibrationLoess
plotSmoothCalibrationRcs
plotSparseCalibration
plotSparseCalibration2
plotSparseRoc
plotVariableScatterplot
}
class PopulationSettings {
createStudyPopulation
createStudyPopulationSettings
getCounts
getCounts2
}
class Predict {
applyFeatureengineering
applyTidyCovariateData
predictPlp
}
class PredictionDistribution {
getPredictionDistribution
getPredictionDistribution_binary
getPredictionDistribution_survival
}
class PreprocessingData {
createPreprocessSettings
preprocessData
}
class RClassifier {
applyCrossValidationInR
fitRclassifier
}
class Recalibration {
inverseLog
logFunct
recalibratePlp
recalibratePlpRefit
recalibrationInTheLarge
weakRecalibration
}
class RunMultiplePlp {
convertToJson
createModelDesign
loadPlpAnalysesJson
runMultiplePlp
savePlpAnalysesJson
validateMultiplePlp
}
class RunPlp {
runPlp
}
class RunPlpHelpers {
checkInputs
createDefaultExecuteSettings
createExecuteSettings
printHeader
}
class Sampling {
createSampleSettings
overSampleData
sameData
sampleData
underSampleData
}
class SaveLoadPlp {
applyMinCellCount
extractDatabaseToCsv
getPlpSensitiveColumns
loadPlpData
loadPlpModel
loadPlpResult
loadPlpShareable
loadPrediction
removeCellCount
removeList
saveModelPart
savePlpData
savePlpModel
savePlpResult
savePlpShareable
savePrediction
}
class Simulation {
simulatePlpData
}
class SklearnClassifier {
checkPySettings
computeGridPerformance
fitPythonModel
fitSklearn
gridCvPython
predictPythonSklearn
predictValues
}
class SklearnClassifierHelpers {
listCartesian
}
class SklearnClassifierSettings {
AdaBoostClassifierInputs
DecisionTreeClassifierInputs
GaussianNBInputs
MLPClassifierInputs
RandomForestClassifierInputs
setAdaBoost
setDecisionTree
setMLP
setNaiveBayes
setRandomForest
setSVM
SVCInputs
}
class SklearnToJson {
deSerializeAdaboost
deSerializeCsrMatrix
deSerializeDecisionTree
deSerializeMlp
deSerializeNaiveBayes
deSerializeRandomForest
deSerializeSVM
deSerializeTree
serializeAdaboost
serializeCsrMatrix
serializeDecisionTree
serializeMLP
serializeNaiveBayes
serializeRandomForest
serializeSVM
serializeTree
sklearnFromJson
sklearnToJson
}
class ThresholdSummary {
accuracy
checkToByTwoTableInputs
diagnosticOddsRatio
f1Score
falseDiscoveryRate
falseNegativeRate
falseOmissionRate
falsePositiveRate
getThresholdSummary
getThresholdSummary_binary
getThresholdSummary_survival
negativeLikelihoodRatio
negativePredictiveValue
positiveLikelihoodRatio
positivePredictiveValue
sensitivity
specificity
stdca
}
class uploadToDatabase {
addCohort
addDatabase
addModel
addMultipleRunPlpToDatabase
addRunPlpToDatabase
checkJson
checkTable
cleanNum
createDatabaseList
createDatabaseSchemaSettings
createPlpResultTables
deleteTables
enc
getCohortDef
getPlpResultTables
getResultLocations
insertModelInDatabase
insertResultsToSqlite
insertRunPlpToSqlite
}
class uploadToDatabaseDiagnostics {
addDiagnosePlpToDatabase
addDiagnostic
addMultipleDiagnosePlpToDatabase
addResultTable
insertDiagnosisToDatabase
}
class uploadToDatabaseModelDesign {
addCovariateSetting
addFESetting
addModelDesign
addModelSetting
addPlpDataSetting
addPopulationSetting
addSampleSetting
addSplitSettings
addTar
addTidySetting
insertModelDesignInDatabase
insertModelDesignSettings
orderJson
}
class uploadToDatabasePerformance {
addAttrition
addCalibrationSummary
addCovariateSummary
addDemographicSummary
addEvaluation
addEvaluationStatistics
addPerformance
addPredictionDistribution
addThresholdSummary
checkResultExists
getColumnNames
insertPerformanceInDatabase
}
class ViewShinyPlp {
viewDatabaseResultPlp
viewMultiplePlp
viewPlp
viewPlps
}
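
A listing like the one above could be generated from the source files rather than maintained by hand. The following is a minimal base R sketch under that assumption. `listTopLevelFunctions` and `toMermaidClass` are hypothetical helpers, not part of PLP, and this is not necessarily how the diagram above was produced. It only finds functions assigned by name at the top level of a file.

```r
# Hypothetical helper: names of functions assigned at the top level of one R file.
listTopLevelFunctions <- function(path) {
  exprs <- parse(file = path, keep.source = FALSE)
  found <- character(0)
  for (e in exprs) {
    # A top-level assignment is a call to `<-` or `=`.
    isAssignment <- is.call(e) &&
      (identical(e[[1]], as.name("<-")) || identical(e[[1]], as.name("=")))
    # Keep it only if a plain name is assigned a `function(...)` expression.
    if (isAssignment && is.name(e[[2]]) &&
        is.call(e[[3]]) && identical(e[[3]][[1]], as.name("function"))) {
      found <- c(found, as.character(e[[2]]))
    }
  }
  sort(found)
}

# Hypothetical helper: one Mermaid "class" per file, one member per function.
toMermaidClass <- function(path) {
  className <- tools::file_path_sans_ext(basename(path))
  c(paste("class", className, "{"), listTopLevelFunctions(path), "}")
}

# Usage on a throwaway file; for the package one would loop over
# list.files("R", pattern = "\\.R$", full.names = TRUE).
demoFile <- file.path(tempdir(), "Demo.R")
writeLines(
  c("alpha <- function(x) x + 1",
    "beta = function() NULL",
    "notAFunction <- 42"),
  demoFile
)
cat(toMermaidClass(demoFile), sep = "\n")
```

Functions created in other ways, for example through `assign()` or inside another function, would not be picked up by this approach.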

Related resources: Draft PR for new model API: https://github.com/OHDSI/PatientLevelPrediction/pull/462

lhjohn commented 3 months ago

To design a PLP system inspired by mlr3, we could organize the PLP components into mlr3's building blocks:

  1. Learner: Corresponds to model type in PLP.
  2. Task: Could correspond to study population and other cohort parameters.
  3. Resample: Could correspond to data splitting in PLP.
  4. Measure: Corresponds to evaluation functions in PLP.
  5. Prediction: Could correspond to the model object in PLP, used for internal and external validation.
  6. Data: Not considered a building block in mlr3, but useful for representing the data object. mlr3 has the DataBackend class for this.
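
To make the mapping above more concrete, here is a minimal base R sketch of how the blocks might fit together. It uses S3 constructors and simulated data. Every name in it (`newTask`, `newLogisticLearner`, `train`, `predictRisk`, `holdoutSplit`, `brier`) is hypothetical and belongs to neither PLP nor mlr3. The prediction object here simply stores predicted risks next to the observed outcome, which is only one possible reading of the Prediction block.

```r
# All names below are hypothetical; none are PLP or mlr3 functions.

# Task: the prediction problem (data plus the name of the outcome column).
newTask <- function(data, outcome) {
  structure(list(data = data, outcome = outcome), class = "plpTask")
}

# Learner: a model type that can be trained and used for prediction.
newLogisticLearner <- function() {
  structure(list(model = NULL), class = c("logisticLearner", "plpLearner"))
}

train <- function(learner, task, rows) UseMethod("train")
train.logisticLearner <- function(learner, task, rows) {
  modelFormula <- stats::reformulate(".", response = task$outcome)
  learner$model <- stats::glm(
    modelFormula,
    data = task$data[rows, , drop = FALSE],
    family = stats::binomial()
  )
  learner
}

# Prediction: predicted risks together with the observed outcome.
predictRisk <- function(learner, task, rows) UseMethod("predictRisk")
predictRisk.logisticLearner <- function(learner, task, rows) {
  newData <- task$data[rows, , drop = FALSE]
  risk <- stats::predict(learner$model, newdata = newData, type = "response")
  structure(
    list(rows = rows, truth = newData[[task$outcome]], risk = unname(risk)),
    class = "plpPrediction"
  )
}

# Resample: a simple holdout split into train and test row indices.
holdoutSplit <- function(task, trainFraction = 0.75, seed = 1) {
  set.seed(seed)
  n <- nrow(task$data)
  trainRows <- sort(sample.int(n, size = floor(trainFraction * n)))
  list(train = trainRows, test = setdiff(seq_len(n), trainRows))
}

# Measure: a function of a prediction object (here the Brier score).
brier <- function(prediction) mean((prediction$risk - prediction$truth)^2)

# Usage on simulated data.
set.seed(42)
x <- rnorm(200)
toyData <- data.frame(x = x, outcome = rbinom(200, 1, plogis(x)))
task <- newTask(toyData, outcome = "outcome")
rows <- holdoutSplit(task)
learner <- train(newLogisticLearner(), task, rows$train)
prediction <- predictRisk(learner, task, rows$test)
brier(prediction)
```

In a layout like this, adding a model type means adding one constructor plus `train` and `predictRisk` methods, while splitting and evaluation code stay unchanged.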

Forcing our existing PLP functions into mlr3 building blocks could look something like this:

PatientLevelPrediction:

classDiagram
class PatientLevelPrediction

PatientLevelPrediction --> HelperFunctions
PatientLevelPrediction --> Fit
PatientLevelPrediction --> Logging
PatientLevelPrediction --> ParamChecks
PatientLevelPrediction --> DatabaseMigration
PatientLevelPrediction --> RunMultiplePlp
PatientLevelPrediction --> RunPlp
PatientLevelPrediction --> RunPlpHelpers
PatientLevelPrediction --> SaveLoadPlp
PatientLevelPrediction --> LearningCurve

class HelperFunctions {
  +configurePython()
  +createTempModelLoc()
  +cut2()
  +ensure_installed()
  +getOs()
  +is_installed()
  +listAppend()
  +nrow()
  +nrow.default()
  +nrow.tbl()
  +removeInvalidString()
  +setPythonEnvironment()
}
class Fit {
  +fitPlp()
}
class Logging {
  +checkFileExists()
  +closeLog()
  +createLog()
  +createLogSettings()
}
class ParamChecks {
  +checkBoolean()
  +checkHigher()
  +checkHigherEqual()
  +checkInStringVector()
  +checkIsClass()
  +checkLower()
  +checkLowerEqual()
  +checkNotNull()
}
class DatabaseMigration {
  +getDataMigrator()
  +migrateDataModel()
}
class RunMultiplePlp {
  +convertToJson()
  +createModelDesign()
  +loadPlpAnalysesJson()
  +runMultiplePlp()
  +savePlpAnalysesJson()
  +validateMultiplePlp()
}
class RunPlp {
  +runPlp()
}
class RunPlpHelpers {
  +checkInputs()
  +createDefaultExecuteSettings()
  +createExecuteSettings()
  +printHeader()
}
class SaveLoadPlp {
  +applyMinCellCount()
  +extractDatabaseToCsv()
  +getPlpSensitiveColumns()
  +loadPlpData()
  +loadPlpModel()
  +loadPlpResult()
  +loadPlpShareable()
  +loadPrediction()
  +removeCellCount()
  +removeList()
  +saveModelPart()
  +savePlpData()
  +savePlpModel()
  +savePlpResult()
  +savePlpShareable()
  +savePrediction()
}
class LearningCurve {
  +createLearningCurve()
  +getTrainFractions()
  +lcWrapper()
  +learningCurveHelper()
  +plotLearningCurve()
}

Data:

classDiagram
class Data

Data --> ExtractData
Data --> Simulation
Data --> PreprocessingData
Data --> FeatureEngineering
Data --> FeatureImportance
Data --> Formatting
Data --> AdditionalCovariates
Data --> AndromedaHelperFunctions

class ExtractData {
  +createDatabaseDetails()
  +createRestrictPlpDataSettings()
  +getPlpData()
  +print.plpData()
  +print.summary.plpData()
  +summary.plpData()
}
class Simulation {
  +simulatePlpData()
}
class PreprocessingData {
  +createPreprocessSettings()
  +preprocessData()
}
class FeatureEngineering {
  +calculateStratifiedMeans()
  +createFeatureEngineeringSettings()
  +createRandomForestFeatureSelection()
  +createSplineSettings()
  +createStratifiedImputationSettings()
  +createUnivariateFeatureSelection()
  +featureEngineer()
  +imputeMissingMeans()
  +randomForestFeatureSelection()
  +splineCovariates()
  +splineMap()
  +stratifiedImputeCovariates()
  +univariateFeatureSelection()
}
class FeatureImportance {
  +permute()
  +permutePerf()
  +pfi()
}
class Formatting {
  +checkRam()
  +MapIds()
  +toSparseM()
}
class AdditionalCovariates {
  +createCohortCovariateSettings()
  +getCohortCovariateData()
}
class AndromedaHelperFunctions {
  +batchRestrict()
  +calculatePrevs()
  +limitCovariatesToPopulation()
}

Resample:

classDiagram
class Resample

class Sampling {
  +createSampleSettings()
  +overSampleData()
  +sameData()
  +sampleData()
  +underSampleData()
}

class DataSplitting {
  +checkInputsSplit()
  +createDefaultSplitSetting()
  +dataSummary()
  +randomSplitter()
  +splitData()
  +subjectSplitter()
  +timeSplitter()
}

Resample --> Sampling
Resample --> DataSplitting

Task:

classDiagram
class Task

Task --> PopulationSettings
Task --> DiagnosePlp

class PopulationSettings {
  +createStudyPopulation()
  +createStudyPopulationSettings()
  +getCounts()
  +getCounts2()
}
class DiagnosePlp {
  +cos_sim()
  +diagnoseMultiplePlp()
  +diagnosePlp()
  +getDiagnostic()
  +getMaxEndDaysFromCovariates()
  +getOutcomeSummary()
  +probastDesign()
  +probastOutcome()
  +probastParticipants()
  +probastPredictors()
}

Learner:

classDiagram
class Learner

Learner --> SklearnToJson
Learner --> SklearnClassifierSettings
Learner --> SklearnClassifierHelpers
Learner --> SklearnClassifier
Learner --> RClassifier
Learner --> KNN
Learner --> LightGBM
Learner --> GradientBoostingMachine
Learner --> CyclopsModels
Learner --> CyclopsSettings

class SklearnToJson {
  +deSerializeAdaboost()
  +deSerializeCsrMatrix()
  +deSerializeDecisionTree()
  +deSerializeMlp()
  +deSerializeNaiveBayes()
  +deSerializeRandomForest()
  +deSerializeSVM()
  +deSerializeTree()
  +serializeAdaboost()
  +serializeCsrMatrix()
  +serializeDecisionTree()
  +serializeMLP()
  +serializeNaiveBayes()
  +serializeRandomForest()
  +serializeSVM()
  +serializeTree()
  +sklearnFromJson()
  +sklearnToJson()
}
class SklearnClassifierSettings {
  +AdaBoostClassifierInputs()
  +DecisionTreeClassifierInputs()
  +GaussianNBInputs()
  +MLPClassifierInputs()
  +RandomForestClassifierInputs()
  +setAdaBoost()
  +setDecisionTree()
  +setMLP()
  +setNaiveBayes()
  +setRandomForest()
  +setSVM()
  +SVCInputs()
}
class SklearnClassifierHelpers {
  +listCartesian()
}
class SklearnClassifier {
  +checkPySettings()
  +computeGridPerformance()
  +fitPythonModel()
  +fitSklearn()
  +gridCvPython()
  +predictPythonSklearn()
  +predictValues()
}
class RClassifier {
  +applyCrossValidationInR()
  +fitRclassifier()
}
class KNN {
  +fitKNN()
  +predictKnn()
  +setKNN()
}
class LightGBM {
  +fitLightGBM()
  +predictLightGBM()
  +setLightGBM()
  +varImpLightGBM()
}
class GradientBoostingMachine {
  +fitXgboost()
  +predictXgboost()
  +setGradientBoostingMachine()
  +varImpXgboost()
}
class CyclopsModels {
  +createCyclopsModel()
  +filterCovariateIds()
  +fitCyclopsModel()
  +getCV()
  +getVariableImportance()
  +modelTypeToCyclopsModelType()
  +predictCyclops()
  +predictCyclopsType()
  +reparamTransferCoefs()
}
class CyclopsSettings {
  +setCoxModel()
  +setIterativeHardThresholding()
  +setLassoLogisticRegression()
}

Measure:

classDiagram
class Measure

Measure --> ViewShinyPlp
Measure --> uploadToDatabasePerformance
Measure --> uploadToDatabase
Measure --> uploadToDatabaseDiagnostics
Measure --> uploadToDatabaseModelDesign
Measure --> ThresholdSummary
Measure --> PredictionDistribution
Measure --> Plotting
Measure --> CovariateSummary
Measure --> EvaluatePlp
Measure --> EvaluationSummary
Measure --> DemographicSummary
Measure --> CalibrationSummary
Measure --> ImportFromCsv

class ViewShinyPlp {
  +viewDatabaseResultPlp()
  +viewMultiplePlp()
  +viewPlp()
  +viewPlps()
}
class uploadToDatabasePerformance {
  +addAttrition()
  +addCalibrationSummary()
  +addCovariateSummary()
  +addDemographicSummary()
  +addEvaluation()
  +addEvaluationStatistics()
  +addPerformance()
  +addPredictionDistribution()
  +addThresholdSummary()
  +checkResultExists()
  +getColumnNames()
  +insertPerformanceInDatabase()
}
class uploadToDatabase {
  +addCohort()
  +addDatabase()
  +addModel()
  +addMultipleRunPlpToDatabase()
  +addRunPlpToDatabase()
  +checkJson()
  +checkTable()
  +cleanNum()
  +createDatabaseList()
  +createDatabaseSchemaSettings()
  +createPlpResultTables()
  +deleteTables()
  +enc()
  +getCohortDef()
  +getPlpResultTables()
  +getResultLocations()
  +insertModelInDatabase()
  +insertResultsToSqlite()
  +insertRunPlpToSqlite()
}
class uploadToDatabaseDiagnostics {
  +addDiagnosePlpToDatabase()
  +addDiagnostic()
  +addMultipleDiagnosePlpToDatabase()
  +addResultTable()
  +insertDiagnosisToDatabase()
}
class uploadToDatabaseModelDesign {
  +addCovariateSetting()
  +addFESetting()
  +addModelDesign()
  +addModelSetting()
  +addPlpDataSetting()
  +addPopulationSetting()
  +addSampleSetting()
  +addSplitSettings()
  +addTar()
  +addTidySetting()
  +insertModelDesignInDatabase()
  +insertModelDesignSettings()
  +orderJson()
}
class ThresholdSummary {
  +accuracy()
  +checkToByTwoTableInputs()
  +diagnosticOddsRatio()
  +f1Score()
  +falseDiscoveryRate()
  +falseNegativeRate()
  +falseOmissionRate()
  +falsePositiveRate()
  +getThresholdSummary()
  +getThresholdSummary_binary()
  +getThresholdSummary_survival()
  +negativeLikelihoodRatio()
  +negativePredictiveValue()
  +positiveLikelihoodRatio()
  +positivePredictiveValue()
  +sensitivity()
  +specificity()
  +stdca()
}
class PredictionDistribution {
  +getPredictionDistribution()
  +getPredictionDistribution_binary()
  +getPredictionDistribution_survival()
}
class Plotting {
  +outcomeSurvivalPlot()
  +plotDemographicSummary()
  +plotF1Measure()
  +plotGeneralizability()
  +plotPlp()
  +plotPrecisionRecall()
  +plotPredictedPDF()
  +plotPredictionDistribution()
  +plotPreferencePDF()
  +plotSmoothCalibration()
  +plotSmoothCalibrationLoess()
  +plotSmoothCalibrationRcs()
  +plotSparseCalibration()
  +plotSparseCalibration2()
  +plotSparseRoc()
  +plotVariableScatterplot()
}
class CovariateSummary {
  +aggregateCovariateSummaries()
  +covariateSummary()
  +covariateSummarySubset()
  +createCovariateSubsets()
  +getCovariatesForGroup()
}
class EvaluatePlp {
  +evaluatePlp()
  +modelBasedConcordance()
}
class EvaluationSummary {
  +aucWithCi()
  +aucWithoutCi()
  +averagePrecision()
  +brierScore()
  +calculateEStatisticsBinary()
  +calibrationInLarge()
  +calibrationInLargeIntercept()
  +calibrationLine()
  +calibrationWeak()
  +computeAuc()
  +getEvaluationStatistics()
  +getEvaluationStatistics_binary()
  +getEvaluationStatistics_survival()
  +ici()
}
class DemographicSummary {
  +getDemographicSummary()
  +getDemographicSummary_binary()
  +getDemographicSummary_survival()
}
class CalibrationSummary {
  +getCalibrationSummary()
  +getCalibrationSummary_binary()
  +getCalibrationSummary_survival()
}
class ImportFromCsv {
  +extractCohortDefinitionsCSV()
  +extractDatabaseListCSV()
  +extractDiagnosticFromCsv()
  +extractObjectFromCsv()
  +getModelDesignCsv()
  +getModelDesignSettingTable()
  +getPerformanceEvaluationCsv()
  +getTableNamesPlp()
  +insertCsvToDatabase()
}

Prediction:

classDiagram
class Prediction

Prediction --> ExternalValidatePlp
Prediction --> Recalibration
Prediction --> Predict

class ExternalValidatePlp {
  +createValidationDesign()
  +createValidationSettings()
  +externalValidateDbPlp()
  +externalValidatePlp()
  +validateExternal()
  +validateModel()
}
class Recalibration {
  +inverseLog()
  +logFunct()
  +recalibratePlp()
  +recalibratePlpRefit()
  +recalibrationInTheLarge()
  +weakRecalibration()
}
class Predict {
  +applyFeatureengineering()
  +applyTidyCovariateData()
  +predictPlp()
}

egillax commented 2 months ago

For information: tidymodels uses parsnip to provide model interfaces. The design is described here:

https://github.com/tidymodels/parsnip/tree/main/R#readme

They seem to be using function calls, although it is a bit complicated.
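
As a rough illustration of what a design based on function calls can look like, here is a minimal base R sketch. It is not parsnip's implementation and uses none of parsnip's functions. `engineRegistry`, `newModelSpec`, `buildFitCall` and `fitSpec` are hypothetical names. The idea is that a model specification is plain data, and a separate step translates it into an unevaluated call that is only evaluated when fitting.

```r
# Hypothetical registry: maps an engine name to the function that fits it
# and to the default arguments passed to that function.
engineRegistry <- list(
  glm = list(
    func = quote(stats::glm),
    defaults = list(family = quote(stats::binomial()))
  ),
  lm = list(func = quote(stats::lm), defaults = list())
)

# A model specification is just data: it records what to fit, not how.
newModelSpec <- function(engine, ...) {
  structure(list(engine = engine, args = list(...)), class = "modelSpec")
}

# Translate the specification into an unevaluated function call.
buildFitCall <- function(spec, formula, dataName = "data") {
  entry <- engineRegistry[[spec$engine]]
  if (is.null(entry)) {
    stop("Unknown engine: ", spec$engine)
  }
  # User-supplied arguments override the engine defaults.
  args <- utils::modifyList(entry$defaults, spec$args)
  as.call(c(
    list(entry$func, formula = formula, data = as.name(dataName)),
    args
  ))
}

# Evaluate the call with the data bound to the placeholder name `data`.
fitSpec <- function(spec, formula, data) {
  fitCall <- buildFitCall(spec, formula)
  eval(fitCall, envir = list(data = data), enclos = parent.frame())
}

# Usage on simulated data.
set.seed(1)
toy <- data.frame(x = rnorm(100))
toy$y <- rbinom(100, 1, plogis(toy$x))
spec <- newModelSpec(engine = "glm")
print(buildFitCall(spec, y ~ x))  # inspect the call before running it
fit <- fitSpec(spec, y ~ x, toy)
```

Because the call can be printed and inspected before it is evaluated, this pattern keeps the model interface function-based while still separating the specification from the fitting engine.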