lhjohn opened 3 months ago
To design a PLP system inspired by mlr3, we can organize PLP components into mlr3 building blocks. Mapping our existing PLP functions onto those building blocks could look something like this:
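For orientation, mlr3's core building blocks (Task, Learner, Resampling, Measure) compose like this. A hypothetical PLP run re-expressed in that vocabulary might look as follows; this is a sketch assuming the mlr3 and mlr3learners packages, and `plp_df` is a hypothetical wide data frame with an `outcome` factor column:

```r
library(mlr3)
library(mlr3learners)

# Task: data plus prediction target (PLP analogue: plpData + study population)
task <- TaskClassif$new(id = "plp", backend = plp_df,
                        target = "outcome", positive = "1")

# Learner: the model interface (PLP analogue: setLassoLogisticRegression() + fitPlp())
learner <- lrn("classif.log_reg", predict_type = "prob")

# Resampling: the split strategy (PLP analogue: createDefaultSplitSetting() + splitData())
resampling <- rsmp("holdout", ratio = 0.75)

# Measure: the evaluation (PLP analogue: evaluatePlp() and friends)
rr <- resample(task, learner, resampling)
rr$aggregate(msr("classif.auc"))
```

The appeal of this decomposition is that each block below could, in principle, slot into one of these four roles.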
PatientLevelPrediction:
classDiagram
class PatientLevelPrediction
PatientLevelPrediction --> HelperFunctions
PatientLevelPrediction --> Fit
PatientLevelPrediction --> Logging
PatientLevelPrediction --> ParamChecks
PatientLevelPrediction --> DatabaseMigration
PatientLevelPrediction --> RunMultiplePlp
PatientLevelPrediction --> RunPlp
PatientLevelPrediction --> RunPlpHelpers
PatientLevelPrediction --> SaveLoadPlp
PatientLevelPrediction --> LearningCurve
class HelperFunctions {
+configurePython()
+createTempModelLoc()
+cut2()
+ensure_installed()
+getOs()
+is_installed()
+listAppend()
+nrow()
+nrow.default()
+nrow.tbl()
+removeInvalidString()
+setPythonEnvironment()
}
class Fit {
+fitPlp()
}
class Logging {
+checkFileExists()
+closeLog()
+createLog()
+createLogSettings()
}
class ParamChecks {
+checkBoolean()
+checkHigher()
+checkHigherEqual()
+checkInStringVector()
+checkIsClass()
+checkLower()
+checkLowerEqual()
+checkNotNull()
}
class DatabaseMigration {
+getDataMigrator()
+migrateDataModel()
}
class RunMultiplePlp {
+convertToJson()
+createModelDesign()
+loadPlpAnalysesJson()
+runMultiplePlp()
+savePlpAnalysesJson()
+validateMultiplePlp()
}
class RunPlp {
+runPlp()
}
class RunPlpHelpers {
+checkInputs()
+createDefaultExecuteSettings()
+createExecuteSettings()
+printHeader()
}
class SaveLoadPlp {
+applyMinCellCount()
+extractDatabaseToCsv()
+getPlpSensitiveColumns()
+loadPlpData()
+loadPlpModel()
+loadPlpResult()
+loadPlpShareable()
+loadPrediction()
+removeCellCount()
+removeList()
+saveModelPart()
+savePlpData()
+savePlpModel()
+savePlpResult()
+savePlpShareable()
+savePrediction()
}
class LearningCurve {
+createLearningCurve()
+getTrainFractions()
+lcWrapper()
+learningCurveHelper()
+plotLearningCurve()
}
Data:
classDiagram
class Data
Data --> ExtractData
Data --> Simulation
Data --> PreprocessingData
Data --> FeatureEngineering
Data --> FeatureImportance
Data --> Formatting
Data --> AdditionalCovariates
Data --> AndromedaHelperFunctions
class ExtractData {
+createDatabaseDetails()
+createRestrictPlpDataSettings()
+getPlpData()
+print.plpData()
+print.summary.plpData()
+summary.plpData()
}
class Simulation {
+simulatePlpData()
}
class PreprocessingData {
+createPreprocessSettings()
+preprocessData()
}
class FeatureEngineering {
+calculateStratifiedMeans()
+createFeatureEngineeringSettings()
+createRandomForestFeatureSelection()
+createSplineSettings()
+createStratifiedImputationSettings()
+createUnivariateFeatureSelection()
+featureEngineer()
+imputeMissingMeans()
+randomForestFeatureSelection()
+splineCovariates()
+splineMap()
+stratifiedImputeCovariates()
+univariateFeatureSelection()
}
class FeatureImportance {
+permute()
+permutePerf()
+pfi()
}
class Formatting {
+checkRam()
+MapIds()
+toSparseM()
}
class AdditionalCovariates {
+createCohortCovariateSettings()
+getCohortCovariateData()
}
class AndromedaHelperFunctions {
+batchRestrict()
+calculatePrevs()
+limitCovariatesToPopulation()
}
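In mlr3 terms, much of the Data block would live behind a DataBackend. A minimal sketch, assuming the long-format covariates have already been pivoted into a wide data.table (the `covariates_wide` object and its columns are hypothetical):

```r
library(mlr3)
library(data.table)

# Hypothetical wide table: one row per subject, one column per covariate
covariates_wide <- data.table(
  rowId = 1:4,
  age = c(63, 71, 58, 80),
  diabetes = c(1, 0, 0, 1),
  outcome = factor(c(1, 0, 0, 1))
)

# Wrap it in a backend; primary_key identifies rows for later subsetting
backend <- as_data_backend(covariates_wide, primary_key = "rowId")
task <- TaskClassif$new(id = "plp", backend = backend,
                        target = "outcome", positive = "1")
```

Since toSparseM() produces sparse matrices, mlr3's DataBackendMatrix (for Matrix-based data) may be the closer fit than a data.table backend.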
Resample:
classDiagram
class Resample
class Sampling {
+createSampleSettings()
+overSampleData()
+sameData()
+sampleData()
+underSampleData()
}
class DataSplitting {
+checkInputsSplit()
+createDefaultSplitSetting()
+dataSummary()
+randomSplitter()
+splitData()
+subjectSplitter()
+timeSplitter()
}
Resample --> Sampling
Resample --> DataSplitting
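The DataSplitting functions could be surfaced through mlr3's ResamplingCustom, which accepts precomputed train/test index sets. A sketch, where `train_ids`/`test_ids` stand in for the output of a PLP splitter such as subjectSplitter() or timeSplitter() (the indices here are made up):

```r
library(mlr3)

task <- tsk("sonar")  # placeholder built-in task; PLP would supply its own

# Hypothetical indices as produced by a subject- or time-based splitter
train_ids <- 1:150
test_ids <- 151:208

resampling <- rsmp("custom")
resampling$instantiate(task,
                       train_sets = list(train_ids),
                       test_sets = list(test_ids))
```

This would let PLP keep its domain-specific splitting logic while still plugging into mlr3's `resample()` machinery.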
Task:
classDiagram
class Task
Task --> PopulationSettings
Task --> DiagnosePlp
class PopulationSettings {
+createStudyPopulation()
+createStudyPopulationSettings()
+getCounts()
+getCounts2()
}
class DiagnosePlp {
+cos_sim()
+diagnoseMultiplePlp()
+diagnosePlp()
+getDiagnostic()
+getMaxEndDaysFromCovariates()
+getOutcomeSummary()
+probastDesign()
+probastOutcome()
+probastParticipants()
+probastPredictors()
}
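mlr3 bundles data and target into a Task, so createStudyPopulation() would roughly correspond to constructing a TaskClassif (or, for time-to-event outcomes, a TaskSurv from the mlr3proba extension). A sketch with a hypothetical `population` data frame:

```r
library(mlr3)

# Hypothetical population: one row per subject with the binary outcome
population <- data.frame(
  age = c(60, 72, 55),
  sex = factor(c("F", "M", "F")),
  outcome = factor(c(1, 0, 0))
)

task <- TaskClassif$new(id = "plpPopulation", backend = population,
                        target = "outcome", positive = "1")

# Stratify future splits by outcome, as PLP's splitters do
task$col_roles$stratum <- "outcome"
```

Time-at-risk and washout settings have no direct mlr3 equivalent, so they would likely remain a PLP-side preprocessing step before the Task is built.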
Learner:
classDiagram
class Learner
Learner --> SklearnToJson
Learner --> SklearnClassifierSettings
Learner --> SklearnClassifierHelpers
Learner --> SklearnClassifier
Learner --> RClassifier
Learner --> KNN
Learner --> LightGBM
Learner --> GradientBoostingMachine
Learner --> CyclopsModels
Learner --> CyclopsSettings
class SklearnToJson {
+deSerializeAdaboost()
+deSerializeCsrMatrix()
+deSerializeDecisionTree()
+deSerializeMlp()
+deSerializeNaiveBayes()
+deSerializeRandomForest()
+deSerializeSVM()
+deSerializeTree()
+serializeAdaboost()
+serializeCsrMatrix()
+serializeDecisionTree()
+serializeMLP()
+serializeNaiveBayes()
+serializeRandomForest()
+serializeSVM()
+serializeTree()
+sklearnFromJson()
+sklearnToJson()
}
class SklearnClassifierSettings {
+AdaBoostClassifierInputs()
+DecisionTreeClassifierInputs()
+GaussianNBInputs()
+MLPClassifierInputs()
+RandomForestClassifierInputs()
+setAdaBoost()
+setDecisionTree()
+setMLP()
+setNaiveBayes()
+setRandomForest()
+setSVM()
+SVCInputs()
}
class SklearnClassifierHelpers {
+listCartesian()
}
class SklearnClassifier {
+checkPySettings()
+computeGridPerformance()
+fitPythonModel()
+fitSklearn()
+gridCvPython()
+predictPythonSklearn()
+predictValues()
}
class RClassifier {
+applyCrossValidationInR()
+fitRclassifier()
}
class KNN {
+fitKNN()
+predictKnn()
+setKNN()
}
class LightGBM {
+fitLightGBM()
+predictLightGBM()
+setLightGBM()
+varImpLightGBM()
}
class GradientBoostingMachine {
+fitXgboost()
+predictXgboost()
+setGradientBoostingMachine()
+varImpXgboost()
}
class CyclopsModels {
+createCyclopsModel()
+filterCovariateIds()
+fitCyclopsModel()
+getCV()
+getVariableImportance()
+modelTypeToCyclopsModelType()
+predictCyclops()
+predictCyclopsType()
+reparamTransferCoefs()
}
class CyclopsSettings {
+setCoxModel()
+setIterativeHardThresholding()
+setLassoLogisticRegression()
}
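Each set*/fit*/predict* trio in the Learner block could become an mlr3 Learner subclass: the settings function becomes a paradox ParamSet, and fit/predict move into private `.train()`/`.predict()` methods. A sketch of that wrapper pattern (the Cyclops delegation is schematic, not the real call signatures):

```r
library(mlr3)
library(R6)
library(paradox)

LearnerClassifCyclops <- R6Class("LearnerClassifCyclops",
  inherit = LearnerClassif,
  public = list(
    initialize = function() {
      super$initialize(
        id = "classif.cyclops",
        feature_types = c("integer", "numeric"),
        predict_types = c("response", "prob"),
        # replaces setLassoLogisticRegression()-style settings objects
        param_set = ps(variance = p_dbl(lower = 0, default = 0.01))
      )
    }
  ),
  private = list(
    .train = function(task) {
      # would delegate to fitCyclopsModel() on task$data()
      stop("sketch only")
    },
    .predict = function(task) {
      # would delegate to predictCyclops() and wrap the result as a PredictionClassif
      stop("sketch only")
    }
  )
)
```

One open question is where PLP's per-learner cross-validation (getCV(), gridCvPython()) would live, since mlr3 normally handles tuning externally via mlr3tuning.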
Measure:
classDiagram
class Measure
Measure --> ViewShinyPlp
Measure --> uploadToDatabasePerformance
Measure --> uploadToDatabase
Measure --> uploadToDatabaseDiagnostics
Measure --> uploadToDatabaseModelDesign
Measure --> ThresholdSummary
Measure --> PredictionDistribution
Measure --> Plotting
Measure --> CovariateSummary
Measure --> EvaluatePlp
Measure --> EvaluationSummary
Measure --> DemographicSummary
Measure --> CalibrationSummary
Measure --> ImportFromCsv
class ViewShinyPlp {
+viewDatabaseResultPlp()
+viewMultiplePlp()
+viewPlp()
+viewPlps()
}
class uploadToDatabasePerformance {
+addAttrition()
+addCalibrationSummary()
+addCovariateSummary()
+addDemographicSummary()
+addEvaluation()
+addEvaluationStatistics()
+addPerformance()
+addPredictionDistribution()
+addThresholdSummary()
+checkResultExists()
+getColumnNames()
+insertPerformanceInDatabase()
}
class uploadToDatabase {
+addCohort()
+addDatabase()
+addModel()
+addMultipleRunPlpToDatabase()
+addRunPlpToDatabase()
+checkJson()
+checkTable()
+cleanNum()
+createDatabaseList()
+createDatabaseSchemaSettings()
+createPlpResultTables()
+deleteTables()
+enc()
+getCohortDef()
+getPlpResultTables()
+getResultLocations()
+insertModelInDatabase()
+insertResultsToSqlite()
+insertRunPlpToSqlite()
}
class uploadToDatabaseDiagnostics {
+addDiagnosePlpToDatabase()
+addDiagnostic()
+addMultipleDiagnosePlpToDatabase()
+addResultTable()
+insertDiagnosisToDatabase()
}
class uploadToDatabaseModelDesign {
+addCovariateSetting()
+addFESetting()
+addModelDesign()
+addModelSetting()
+addPlpDataSetting()
+addPopulationSetting()
+addSampleSetting()
+addSplitSettings()
+addTar()
+addTidySetting()
+insertModelDesignInDatabase()
+insertModelDesignSettings()
+orderJson()
}
class ThresholdSummary {
+accuracy()
+checkToByTwoTableInputs()
+diagnosticOddsRatio()
+f1Score()
+falseDiscoveryRate()
+falseNegativeRate()
+falseOmissionRate()
+falsePositiveRate()
+getThresholdSummary()
+getThresholdSummary_binary()
+getThresholdSummary_survival()
+negativeLikelihoodRatio()
+negativePredictiveValue()
+positiveLikelihoodRatio()
+positivePredictiveValue()
+sensitivity()
+specificity()
+stdca()
}
class PredictionDistribution {
+getPredictionDistribution()
+getPredictionDistribution_binary()
+getPredictionDistribution_survival()
}
class Plotting {
+outcomeSurvivalPlot()
+plotDemographicSummary()
+plotF1Measure()
+plotGeneralizability()
+plotPlp()
+plotPrecisionRecall()
+plotPredictedPDF()
+plotPredictionDistribution()
+plotPreferencePDF()
+plotSmoothCalibration()
+plotSmoothCalibrationLoess()
+plotSmoothCalibrationRcs()
+plotSparseCalibration()
+plotSparseCalibration2()
+plotSparseRoc()
+plotVariableScatterplot()
}
class CovariateSummary {
+aggregateCovariateSummaries()
+covariateSummary()
+covariateSummarySubset()
+createCovariateSubsets()
+getCovariatesForGroup()
}
class EvaluatePlp {
+evaluatePlp()
+modelBasedConcordance()
}
class EvaluationSummary {
+aucWithCi()
+aucWithoutCi()
+averagePrecision()
+brierScore()
+calculateEStatisticsBinary()
+calibrationInLarge()
+calibrationInLargeIntercept()
+calibrationLine()
+calibrationWeak()
+computeAuc()
+getEvaluationStatistics()
+getEvaluationStatistics_binary()
+getEvaluationStatistics_survival()
+ici()
}
class DemographicSummary {
+getDemographicSummary()
+getDemographicSummary_binary()
+getDemographicSummary_survival()
}
class CalibrationSummary {
+getCalibrationSummary()
+getCalibrationSummary_binary()
+getCalibrationSummary_survival()
}
class ImportFromCsv {
+extractCohortDefinitionsCSV()
+extractDatabaseListCSV()
+extractDiagnosticFromCsv()
+extractObjectFromCsv()
+getModelDesignCsv()
+getModelDesignSettingTable()
+getPerformanceEvaluationCsv()
+getTableNamesPlp()
+insertCsvToDatabase()
}
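Similarly, the evaluation functions could be wrapped as mlr3 Measures, so that evaluatePlp() becomes a list of `msr()` objects. A sketch wrapping a Brier-style score (delegating to PLP's brierScore() would replace the inline arithmetic; the positive-class handling assumes mlr3's convention that the positive class is the first factor level):

```r
library(mlr3)
library(R6)

MeasureClassifPlpBrier <- R6Class("MeasureClassifPlpBrier",
  inherit = MeasureClassif,
  public = list(
    initialize = function() {
      super$initialize(
        id = "classif.plp_brier",
        range = c(0, 1),
        minimize = TRUE,
        predict_type = "prob"
      )
    }
  ),
  private = list(
    .score = function(prediction, ...) {
      positive <- colnames(prediction$prob)[1]  # first level is the positive class
      p <- prediction$prob[, positive]
      y <- as.integer(prediction$truth == positive)
      mean((p - y)^2)
    }
  )
)
```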
Prediction:
classDiagram
class Prediction
Prediction --> ExternalValidatePlp
Prediction --> Recalibration
Prediction --> Predict
class ExternalValidatePlp {
+createValidationDesign()
+createValidationSettings()
+externalValidateDbPlp()
+externalValidatePlp()
+validateExternal()
+validateModel()
}
class Recalibration {
+inverseLog()
+logFunct()
+recalibratePlp()
+recalibratePlpRefit()
+recalibrationInTheLarge()
+weakRecalibration()
}
class Predict {
+applyFeatureengineering()
+applyTidyCovariateData()
+predictPlp()
}
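In mlr3, a trained Learner yields a Prediction object, and external validation amounts to predicting on new data. predictPlp()/externalValidatePlp() could then reduce to something like the following sketch (built-in task and learner used as stand-ins):

```r
library(mlr3)

task <- tsk("sonar")
learner <- lrn("classif.rpart", predict_type = "prob")

learner$train(task, row_ids = 1:150)

# internal prediction on held-out rows
internal <- learner$predict(task, row_ids = 151:208)

# "external validation": score the same trained model on new data
external <- learner$predict_newdata(task$data(rows = 1:20))
external$score(msr("classif.auc"))
```

Recalibration has no mlr3 counterpart, so recalibratePlp() and friends would presumably stay PLP-specific post-processing on the Prediction object.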
For information: tidymodels uses parsnip to provide model interfaces. They describe their design here:
https://github.com/tidymodels/parsnip/tree/main/R#readme
They seem to use plain function calls rather than R6, although the design is somewhat involved.
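For comparison, parsnip's interface is function-based: a model specification is built from composable calls and fitted with a formula. A minimal sketch (assumes parsnip plus the glmnet engine; `train` and `test` are hypothetical data frames with an `outcome` factor):

```r
library(parsnip)

spec <- logistic_reg(penalty = 0.1, mixture = 1) |>  # mixture = 1 is the lasso
  set_engine("glmnet") |>
  set_mode("classification")

fitted <- fit(spec, outcome ~ ., data = train)
predict(fitted, new_data = test, type = "prob")
```

So the design choice for PLP would be between mlr3's R6 objects and this kind of functional specification layer, which is closer to PLP's current set*() functions.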
A place to discuss the refactor of PLP and to get an overview of the current code base and of options for a prospective one. Currently the project is file-organized and function-based; the diagrams above give a "class" view of all files and functions in the R folder.
Related resources: Draft PR for new model API: https://github.com/OHDSI/PatientLevelPrediction/pull/462