jpmml / r2pmml

R library for converting R models to PMML
GNU Affero General Public License v3.0
73 stars 18 forks source link

Support Weight and Base_Margin in genDMatrix #61

Closed SixiangHu closed 4 years ago

SixiangHu commented 4 years ago

It is quite common to set a weight and a base_margin for each observation before training an xgboost model.

The xgboost DMatrix supports this when loading a .libsvm file: it also reads auxiliary files stored alongside it. https://xgboost.readthedocs.io/en/latest/tutorials/input_format.html#embedding-additional-information-inside-libsvm-file

I have checked the current genDMatrix code and found that both parameters can be added quite easily.

The concerns here are:

Is there any further concern why this was not added as a feature?

# Writes a data set to a LibSVM-formatted text file and loads that file as an
# xgboost DMatrix.
#
# Factor columns are one-hot encoded: each factor occupies as many feature
# indices as it has levels, and a row emits "<index>:1" for its own level.
# Every other column occupies one feature index and emits "<index>:<value>".
# Cells that are NA are left out of the row. Feature indices are 0-based.
#
# Arguments:
#   df_y        Labels, indexed as df_y[i] for row i of df_X.
#   df_X        Data frame of features.
#   file        Path of the LibSVM file to write.
#   weight      Optional per-row values, written one per line to
#               "<file>.weight". Must have one element per row of df_X.
#   base_margin Optional per-row values, written one per line to
#               "<file>.base_margin". Must have one element per row of df_X.
#
# Returns the result of xgboost::xgb.DMatrix(file).
genDMatrix = function(df_y, df_X, 
                       file = tempfile(pattern = "DMatrix", fileext = ".libsvm"),
                       weight=NULL,
                       base_margin=NULL){
  n_row = nrow(df_X)

  # Fail early instead of silently writing "NA" lines for a short vector or
  # ignoring the tail of a long one.
  check_length = function(values, label){
    if(!is.null(values) && length(values) != n_row){
      stop(label, " must have one element per row of df_X (expected ", n_row,
        ", got ", length(values), ")", call. = FALSE)
    }
  }
  check_length(weight, "weight")
  check_length(base_margin, "base_margin")

  # Number of feature indices occupied by a column.
  col2len = function(col){
    if(is.factor(col)){
      return (length(levels(col)))
    }
    return (1)
  }
  col_len = vapply(df_X, FUN = col2len, FUN.VALUE = numeric(1))
  # Index of the first feature slot of every column.
  col_offset = (cumsum(col_len) - col_len)

  # Position of a cell inside its column's slots (the 0-based level number
  # for factors, 0 otherwise); NA cells stay NA so they can be dropped.
  factor2pos = function(x){
    if(is.na(x)){
      return (NA)
    }
    if(is.factor(x)){
      return (as.integer(x) - 1)
    }
    return (0)
  }

  # Value written for a cell: 1 for a factor level, the raw value otherwise.
  format_cell = function(x){
    if(is.na(x)){
      return (NA)
    }
    if(is.factor(x)){
      return (1)
    }
    return (x)
  }

  row2line = function(i){
    # drop = FALSE keeps a one-column data frame from collapsing to a vector.
    row = df_X[i, , drop = FALSE]
    cell_offset = (col_offset + vapply(row, FUN = factor2pos, FUN.VALUE = numeric(1)))
    # sapply is kept here because the cell type depends on the column type.
    cell_value = sapply(row, FUN = format_cell)

    # Indices are converted to integer so that large ones are never printed
    # in scientific notation (as.character(100000) is "1e+05").
    X_values = paste(as.integer(na.omit(cell_offset)), na.omit(cell_value), sep = ":", collapse = " ")
    paste0(df_y[i], " ", X_values)
  }
  # seq_len() yields an empty sequence for a zero-row data frame, whereas
  # 1:nrow(df_X) would iterate over c(1, 0).
  lines = vapply(seq_len(n_row), FUN = row2line, FUN.VALUE = character(1))

  # writeLines() opens and closes the file itself, so no connection is left
  # open if an error occurs, and the content is flushed before xgboost reads it.
  writeLines(lines, file)

  if(!is.null(weight)){
    writeLines(as.character(weight), paste0(file, ".weight"))
  }
  if(!is.null(base_margin)){
    writeLines(as.character(base_margin), paste0(file, ".base_margin"))
  }

  return (xgboost::xgb.DMatrix(file))
}