OHDSI / PatientLevelPrediction

An R package for performing patient level prediction in an observational database in the OMOP Common Data Model.
https://ohdsi.github.io/PatientLevelPrediction
188 stars 89 forks source link

Add basic imputation for measurement values. #461

Open egillax opened 5 months ago

egillax commented 5 months ago

I think basic imputation is functionality the package should have built in: imputing missing values with the mean, the median, or some other simply calculated single value. This would primarily be useful for measurement values. It would also be nice to have a threshold for including a given measurement, e.g. if more than 50% of its values are missing, don't include it. Alternatively, the threshold could be based on the absolute number of measurements, since otherwise the imputation would be too noisy.

We already have the new age-stratified imputation, but I think that is more for advanced use cases.

This should be very straightforward to add using the featureEngineering api.

egillax commented 2 weeks ago

See scikit-learn for inspiration: https://scikit-learn.org/1.5/modules/impute.html

egillax commented 1 week ago

@jreps a kind reminder that you promised me some existing code for this :)

egillax commented 5 days ago

I think I found the code; it is here: https://github.com/ohdsi-studies/PlpMeasurementFeasability/

#' Extract a single measurement-value covariate for the cohort
#'
#' Covariate builder named by the `fun` attribute that
#' `createMeasurementCovariateSettings()` sets. For every cohort row it keeps
#' the most recent non-null `value_as_number` of a measurement whose concept is
#' in `covariateSettings$conceptSet` and whose unit is in
#' `covariateSettings$conceptUnitSet`. Only measurements dated between
#' `startDay` and `endDay` days relative to `cohort_start_date` are considered.
#'
#' @param connection DatabaseConnector connection; its `dbms` attribute selects
#'   the SQL dialect.
#' @param oracleTempSchema Passed through to `SqlRender::translate()`.
#' @param cdmDatabaseSchema Schema holding the CDM `measurement` and `person`
#'   tables.
#' @param cdmVersion Not used by this function.
#' @param cohortTable Cohort table with `subject_id`, `cohort_start_date` and
#'   the row id column.
#' @param rowIdField Name of the row id column in `cohortTable`.
#' @param aggregated Must be `FALSE` (or omitted); aggregated covariates are
#'   not supported.
#' @param cohortId Not used by this function.
#' @param covariateSettings Settings list built by
#'   `createMeasurementCovariateSettings()`.
#' @return A `CovariateData` Andromeda object with `covariates`,
#'   `covariateRef` and `analysisRef` tables; the extraction SQL and the call
#'   are stored in the `metaData` attribute.
getMeasurementCovariateData <- function(connection,
                                        oracleTempSchema = NULL,
                                        cdmDatabaseSchema,
                                        cdmVersion = "5",
                                        cohortTable = "#cohort_person",
                                        rowIdField = "row_id",
                                        aggregated,
                                        cohortId,
                                        covariateSettings) {
  # Fail loudly instead of silently returning person-level data when the
  # caller asked for aggregated covariates.
  if (!missing(aggregated) && isTRUE(aggregated)) {
    stop("Aggregation not supported for measurement covariates", call. = FALSE)
  }
  # An empty set would render `in ()`, which is invalid SQL.
  if (length(covariateSettings$conceptSet) == 0 ||
      length(covariateSettings$conceptUnitSet) == 0) {
    stop("conceptSet and conceptUnitSet must each contain at least one concept id",
         call. = FALSE)
  }

  ParallelLogger::logInfo(paste0('Extracting measurement ', covariateSettings$covariateId))

  # Most recent measurement per cohort row inside the time window. Null values
  # are excluded *before* ranking so that a null latest record does not hide
  # an earlier valid value.
  sql <- paste(
    "select * from (select c.@row_id_field AS row_id, measurement_concept_id, unit_concept_id,",
    "measurement_date, abs(datediff(dd, measurement_date, c.cohort_start_date)) as index_time,",
    "value_as_number raw_value,",
    "row_number() over (partition by c.@row_id_field order by measurement_date desc) as rn,",
    "@covariate_id as covariate_id",
    "from @cdm_database_schema.measurement m",
    "inner join @cohort_temp_table c on c.subject_id = m.person_id",
    "and measurement_date >= dateadd(day, @startDay, cohort_start_date)",
    "and measurement_date <= dateadd(day, @endDay, cohort_start_date)",
    "inner join @cdm_database_schema.person p on p.person_id=c.subject_id",
    "where m.measurement_concept_id in (@concepts)",
    "and m.unit_concept_id in (@units)",
    "and m.value_as_number is not null) temp where rn = 1;"
  )

  sql <- SqlRender::render(sql,
                           cohort_temp_table = cohortTable,
                           row_id_field = rowIdField,
                           startDay = covariateSettings$startDay,
                           endDay = covariateSettings$endDay,
                           concepts = paste(covariateSettings$conceptSet, collapse = ','),
                           units = paste(covariateSettings$conceptUnitSet, collapse = ','),
                           cdm_database_schema = cdmDatabaseSchema,
                           covariate_id = covariateSettings$covariateId
  )
  sql <- SqlRender::translate(sql, targetDialect = attr(connection, "dbms"),
                              oracleTempSchema = oracleTempSchema)
  # Retrieve the covariate:
  covariates <- DatabaseConnector::querySql(connection, sql, integer64AsNumeric = TRUE)
  # Convert column names to camelCase:
  colnames(covariates) <- SqlRender::snakeCaseToCamelCase(colnames(covariates))

  ParallelLogger::logInfo('Extracted data')

  # Defensive: the SQL already excludes nulls, so this should report 0.
  ParallelLogger::logInfo(paste0(sum(is.na(covariates$rawValue)), ' NA values'))
  covariates <- covariates[!is.na(covariates$rawValue), ]
  ParallelLogger::logInfo(paste0(nrow(covariates), ' patients with measurement'))

  if (nrow(covariates) > 0) {
    if (is.null(covariateSettings$scaleMap)) {
      # No mapping supplied (the settings default): use the raw value as-is.
      covariates <- data.frame(rowId = covariates$rowId,
                               covariateId = covariates$covariateId,
                               covariateValue = covariates$rawValue)
    } else {
      covariates <- covariateSettings$scaleMap(covariates)
    }
  } else {
    # Keep the covariates table in the standard three-column layout even when
    # nothing was found.
    covariates <- data.frame(rowId = numeric(0),
                             covariateId = numeric(0),
                             covariateValue = numeric(0))
  }

  # Imputation of cohort rows without a measurement is currently disabled:
  # covariateSettings$imputationValue is not used. The original draft is kept
  # below for reference.
  # impute missing - add age here to be able to input age interaction
  #sql <- paste("select distinct c.@row_id_field AS row_id ",
  #             ", LOG(YEAR(c.cohort_start_date)-p.year_of_birth)  as age",
  #             "from @cohort_temp_table c",
  #             "inner join @cdm_database_schema.person p on p.person_id=c.subject_id")

  #sql <- SqlRender::render(sql, cohort_temp_table = cohortTable,
  #                        row_id_field = rowIdField,
  #                        cdm_database_schema = cdmDatabaseSchema)
  #sql <- SqlRender::translate(sql, targetDialect = attr(connection, "dbms"),
  #                            oracleTempSchema = oracleTempSchema)
  # Retrieve the covariate:
  #ppl <- DatabaseConnector::querySql(connection, sql)
  #colnames(ppl) <- SqlRender::snakeCaseToCamelCase(colnames(ppl))

  #missingPlp <- ppl[!ppl$rowId%in%covariates$rowId,]
  #if(length(missingPlp$rowId)>0){
  #  
  #  if(covariateSettings$lnValue){
  #    covariateSettings$imputationValue <- log(covariateSettings$imputationValue)
  #  }

  #  if(covariateSettings$ageInteraction){
  #    covVal <- missingPlp$age*covariateSettings$imputationValue
  # } else if(covariateSettings$lnAgeInteraction){
  #   covVal <- log(missingPlp$age)*covariateSettings$imputationValue
  # } else{
  #   covVal <- covariateSettings$imputationValue
  # }

  # extraData <- data.frame(rowId = missingPlp$rowId, 
  #                         covariateId = covariateSettings$covariateId, 
  #                         covariateValue = covVal)
  # covariates <- rbind(covariates, extraData[,colnames(covariates)])
  #}

  ParallelLogger::logInfo('Processed data')

  # Construct covariate reference:
  covariateRef <- data.frame(covariateId = covariateSettings$covariateId,
                             covariateName = paste('Measurement during day',
                                                   covariateSettings$startDay,
                                                   'through',
                                                   covariateSettings$endDay,
                                                   'days relative to index:',
                                                   covariateSettings$covariateName
                             ),
                             analysisId = covariateSettings$analysisId,
                             conceptId = 0)

  analysisRef <- data.frame(analysisId = covariateSettings$analysisId,
                            analysisName = "measurement covariate",
                            domainId = "measurement covariate",
                            startDay = covariateSettings$startDay,
                            endDay = covariateSettings$endDay,
                            isBinary = "N",
                            missingMeansZero = "Y")

  metaData <- list(sql = sql, call = match.call())
  result <- Andromeda::andromeda(covariates = covariates,
                                 covariateRef = covariateRef,
                                 analysisRef = analysisRef)
  attr(result, "metaData") <- metaData
  class(result) <- "CovariateData"
  return(result)
}

#' Create settings for a single measurement-value covariate
#'
#' @param covariateName Label appended to the covariate name in the
#'   covariate reference.
#' @param conceptSet Measurement concept ids to include; must be non-empty.
#' @param conceptUnitSet Unit concept ids to include; must be non-empty.
#' @param cohortDatabaseSchema,cohortTable,cohortId Accepted for backward
#'   compatibility but not stored in the returned settings.
#' @param startDay,endDay Window, in days relative to the cohort start date, in
#'   which measurements are considered; `startDay` must not exceed `endDay`.
#' @param scaleMap `NULL` or a function applied to the extracted measurement
#'   data frame.
#' @param imputationValue Value stored in the settings for imputation.
#' @param covariateId Id of the produced covariate.
#' @param analysisId Id of the analysis the covariate belongs to.
#' @return A list of class `covariateSettings` whose `fun` attribute is
#'   `"getMeasurementCovariateData"`.
createMeasurementCovariateSettings <- function(covariateName, conceptSet, conceptUnitSet,
                                               cohortDatabaseSchema, cohortTable, cohortId,
                                               startDay = -30, endDay = 0,
                                               scaleMap = NULL,
                                               imputationValue = 0,
                                               covariateId = 1466,
                                               analysisId = 466
) {
  # Validate up front: these mistakes otherwise surface later as invalid SQL
  # (`in ()`), an empty time window, or "attempt to apply non-function".
  if (length(conceptSet) == 0) {
    stop("conceptSet must contain at least one concept id", call. = FALSE)
  }
  if (length(conceptUnitSet) == 0) {
    stop("conceptUnitSet must contain at least one unit concept id", call. = FALSE)
  }
  if (isTRUE(startDay > endDay)) {
    stop("startDay must be less than or equal to endDay", call. = FALSE)
  }
  if (!is.null(scaleMap) && !is.function(scaleMap)) {
    stop("scaleMap must be NULL or a function", call. = FALSE)
  }

  covariateSettings <- list(covariateName = covariateName,
                            conceptSet = conceptSet,
                            conceptUnitSet = conceptUnitSet,
                            startDay = startDay,
                            endDay = endDay,
                            scaleMap = scaleMap,
                            imputationValue = imputationValue,
                            covariateId = covariateId,
                            analysisId = analysisId
  )

  # The covariate builder is looked up by name, so it must match the
  # extraction function's name exactly.
  attr(covariateSettings, "fun") <- "getMeasurementCovariateData"
  class(covariateSettings) <- "covariateSettings"
  return(covariateSettings)
}