salesforce / TransmogrifAI

TransmogrifAI (pronounced trăns-mŏgˈrə-fī) is an AutoML library for building modular, reusable, strongly typed machine learning workflows on Apache Spark with minimal hand-tuning
https://transmogrif.ai
BSD 3-Clause "New" or "Revised" License

Limit on the number of columns, arguments, or variables of the input dataset #328

Closed hzrt123 closed 5 years ago

hzrt123 commented 5 years ago

TransmogrifAI doesn't work when the input dataset has more than around 100 columns?

As in the example, we use a 'case class' to define the input data structure.

When the number of columns (i.e. arguments or fields of the class) exceeds around 100, there are errors. In the build stage, it throws java.lang.StackOverflowError (but eventually passes). In the run stage, it throws java.lang.ClassFormatError: too many arguments in method signature in class file.

I did some research and found two possible reasons:

  1. in the Scala language, there is a limit on the number of fields of a 'case class'

  2. in the JVM and Spark, there is a limit on the number of columns of an 'RDD' or 'DataFrame'

Could someone help with this issue? How do you deal with a large number of columns in TransmogrifAI?

Thank you.

tovbinm commented 5 years ago

What's the input data format? What errors are you getting?

hzrt123 commented 5 years ago

It is a .csv file.

object OpTitanicMini {

  case class Passenger
  (
    v1: Double,
    v10: Double,
    v100: Double,
    v101: Double,
    v102: Double,
    v103: Double,
    v104: Double,
    v105: Double,
    v106: Double,
    v107: Double,
    v108: Double,
    v109: Double,
    v11: Double,
    v110: Double,
    v111: Double,
    v112: Double,
    v113: Double,
    v114: Double,
    v115: Double,
    v116: Double,
    v117: Double,
    v118: Double,
    v119: Double,
    v12: Double,
    v120: Double,
    v121: Double,
    v122: Double,
    v123: Double,
    v124: Double,
    v125: Double,
    v126: Double,
    v127: Double,
    v128: Double,
    v129: Double,
    v13: Double,
    v130: Double,
    v131: Double,
    v14: Double,
    v15: Double,
    v16: Double,
    v17: Double,
    v18: Double,
    v19: Double,
    v2: Double,
    v20: Double,
    v21: Double,
    v23: Double,
    v24: Double,
    v25: Double,
    v26: Double,
    v27: Double,
    v28: Double,
    v29: Double,
    v3: Double,
    v30: Double,
    v31: Double,
    v32: Double,
    v33: Double,
    v34: Double,
    v35: Double,
    v36: Double,
    v37: Double,
    v38: Double,
    v39: Double,
    v4: Double,
    v40: Double,
    v41: Double,
    v42: Double,
    v43: Double,
    v44: Double,
    v45: Double,
    v46: Double,
    v47: Double,
    v48: Double,
    v49: Double,
    v5: Double,
    v50: Double,
    v51: Double,
    v52: Double,
    v53: Double,
    v54: Double,
    v55: Double,
    v56: Double,
    v57: Double,
    v58: Double,
    v59: Double,
    v6: Double,
    v60: Double,
    v61: Double,
    v62: Double,
    v63: Double,
    v64: Double,
    v65: Double,
    v66: Double,
    v67: Double,
    v68: Double,
    v69: Double,
    v7: Double,
    v70: Double,
    v71: Double,
    v72: Double,
    v73: Double,
    v74: Double,
    v75: Double,
    v76: Double,
    v77: Double,
    v78: Double,
    v79: Double,
    v8: Double,
    v80: Double,
    v81: Double,
    v82: Double,
    v83: Double,
    v84: Double,
    v85: Double,
    v86: Double,
    v87: Double,
    v88: Double,
    v89: Double,
    v9: Double,
    v90: Double,
    v92: Double,
    v93: Double,
    v94: Double,
    v95: Double,
    v96: Double,
    v97: Double,
    v98: Double,
    v99: Double,
    target: Double
  )

  def main(args: Array[String]): Unit = {
    LogManager.getLogger("com.salesforce.op").setLevel(Level.ERROR)
    implicit val spark = SparkSession.builder.config(new SparkConf()).getOrCreate()
    import spark.implicits._

    // Read the input data as a DataFrame
    val pathToData = Option(args(0))
    val passengersData = DataReaders.Simple.csvCase[Passenger](pathToData, key = ReaderKey.randomKey)
      .readDataset().toDF()

    // Automated feature engineering
    val (survived, features) = FeatureBuilder.fromDataFrame[RealNN](passengersData, response = "target")
    val featureVector = features.transmogrify()

    // Automated feature selection
    val checkedFeatures = survived.sanityCheck(featureVector, checkSample = 1.0, removeBadFeatures = true)

    // Automated model selection
    val prediction = BinaryClassificationModelSelector
      .withCrossValidation(modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier))
      .setInput(survived, checkedFeatures).getOutput()
    val model = new OpWorkflow().setInputDataset(passengersData).setResultFeatures(prediction).train()

    println("Model summary:\n" + model.summaryPretty())
  }

}


tovbinm commented 5 years ago

You can use Spark directly, bypassing the readers, i.e.

val passengersData: DataFrame = spark.read.csv(pathToData)

This has a maximum of 20480 columns by default, which you can modify with .options - https://spark.apache.org/docs/2.3.3/api/java/org/apache/spark/sql/DataFrameReader.html#csv-scala.collection.Seq-
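For example, a minimal sketch of raising that limit (maxColumns is the Spark CSV option behind the 20480 default; pathToData is assumed to hold the file path):

val passengersData: DataFrame = spark.read
  .option("maxColumns", "50000") // raise the CSV parser's column limit above the default 20480
  .csv(pathToData)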

hzrt123 commented 5 years ago

I tried your solution, but it still failed.

  1. the 'case class' still has the build error.

  2. the code after the reader requires 'schema' information, e.g. where the 'target' feature is.

  3. I checked the source code of the data reader and found that spark.sql.DataFrameReader is actually already used under the hood.

Could you perhaps show me a complete solution?

Thank you.

tovbinm commented 5 years ago
  1. You don't really need the Passenger (or any other) case class when using a DataFrame, since the data is loaded into the Row type, which is essentially a map of key -> value.
  2. Depending on whether or not your csv file has a header, loading can be done as follows:
    val passengersData: DataFrame = spark.read
    .option("header", "true") // if you have the header in your csv file to make the schema match 
    .option("inferSchema", "true") // allows to infer data types by Spark 
    .csv(pathToData)

    (more options here - https://docs.databricks.com/spark/latest/data-sources/read-csv.html). Next, you need to make sure you correctly point TransmogrifAI to the response column ("target" below) by inspecting the loaded dataframe and renaming the column if needed - https://stackoverflow.com/questions/35592917/renaming-column-names-of-a-dataframe-in-spark-scala
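
    For illustration, a minimal sketch of that renaming step (assuming the raw csv names the response column "label"; that name is hypothetical):

    val renamedData = passengersData.withColumnRenamed("label", "target") // rename the raw response column
    val (survived, features) = FeatureBuilder.fromDataFrame[RealNN](renamedData, response = "target")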

hzrt123 commented 5 years ago

I used the following code, added the schema manually, and it works.

Thanks a lot.

object OpTitanicMini {

  def main(args: Array[String]): Unit = {
    LogManager.getLogger("com.salesforce.op").setLevel(Level.ERROR)
    implicit val spark = SparkSession.builder.config(new SparkConf()).getOrCreate()
    import spark.implicits._

    // Read the input data as a DataFrame
    val pathToData = Option(args(0))
    // val passengersData = DataReaders.Simple.csvCase[Passenger](pathToData, key = ReaderKey.randomKey)
    //   .readDataset().toDF()

    // scalastyle:off
    val schema = new StructType()
      .add("v1",DoubleType,true)
      .add("v10",DoubleType,true)
      .add("v100",DoubleType,true)
      .add("v101",DoubleType,true)
      .add("v102",DoubleType,true)
      .add("v103",DoubleType,true)
      .add("v104",DoubleType,true)
      .add("v105",DoubleType,true)
      .add("v106",DoubleType,true)
      .add("v107",DoubleType,true)
      .add("v108",DoubleType,true)
      .add("v109",DoubleType,true)
      .add("v11",DoubleType,true)
      .add("v110",DoubleType,true)
      .add("v111",DoubleType,true)
      .add("v112",DoubleType,true)
      .add("v113",DoubleType,true)
      .add("v114",DoubleType,true)
      .add("v115",DoubleType,true)
      .add("v116",DoubleType,true)
      .add("v117",DoubleType,true)
      .add("v118",DoubleType,true)
      .add("v119",DoubleType,true)
      .add("v12",DoubleType,true)
      .add("v120",DoubleType,true)
      .add("v121",DoubleType,true)
      .add("v122",DoubleType,true)
      .add("v123",DoubleType,true)
      .add("v124",DoubleType,true)
      .add("v125",DoubleType,true)
      .add("v126",DoubleType,true)
      .add("v127",DoubleType,true)
      .add("v128",DoubleType,true)
      .add("v129",DoubleType,true)
      .add("v13",DoubleType,true)
      .add("v130",DoubleType,true)
      .add("v131",DoubleType,true)
      .add("v14",DoubleType,true)
      .add("v15",DoubleType,true)
      .add("v16",DoubleType,true)
      .add("v17",DoubleType,true)
      .add("v18",DoubleType,true)
      .add("v19",DoubleType,true)
      .add("v2",DoubleType,true)
      .add("v20",DoubleType,true)
      .add("v21",DoubleType,true)
      .add("v23",DoubleType,true)
      .add("v24",DoubleType,true)
      .add("v25",DoubleType,true)
      .add("v26",DoubleType,true)
      .add("v27",DoubleType,true)
      .add("v28",DoubleType,true)
      .add("v29",DoubleType,true)
      .add("v3",DoubleType,true)
      .add("v30",DoubleType,true)
      .add("v31",DoubleType,true)
      .add("v32",DoubleType,true)
      .add("v33",DoubleType,true)
      .add("v34",DoubleType,true)
      .add("v35",DoubleType,true)
      .add("v36",DoubleType,true)
      .add("v37",DoubleType,true)
      .add("v38",DoubleType,true)
      .add("v39",DoubleType,true)
      .add("v4",DoubleType,true)
      .add("v40",DoubleType,true)
      .add("v41",DoubleType,true)
      .add("v42",DoubleType,true)
      .add("v43",DoubleType,true)
      .add("v44",DoubleType,true)
      .add("v45",DoubleType,true)
      .add("v46",DoubleType,true)
      .add("v47",DoubleType,true)
      .add("v48",DoubleType,true)
      .add("v49",DoubleType,true)
      .add("v5",DoubleType,true)
      .add("v50",DoubleType,true)
      .add("v51",DoubleType,true)
      .add("v52",DoubleType,true)
      .add("v53",DoubleType,true)
      .add("v54",DoubleType,true)
      .add("v55",DoubleType,true)
      .add("v56",DoubleType,true)
      .add("v57",DoubleType,true)
      .add("v58",DoubleType,true)
      .add("v59",DoubleType,true)
      .add("v6",DoubleType,true)
      .add("v60",DoubleType,true)
      .add("v61",DoubleType,true)
      .add("v62",DoubleType,true)
      .add("v63",DoubleType,true)
      .add("v64",DoubleType,true)
      .add("v65",DoubleType,true)
      .add("v66",DoubleType,true)
      .add("v67",DoubleType,true)
      .add("v68",DoubleType,true)
      .add("v69",DoubleType,true)
      .add("v7",DoubleType,true)
      .add("v70",DoubleType,true)
      .add("v71",DoubleType,true)
      .add("v72",DoubleType,true)
      .add("v73",DoubleType,true)
      .add("v74",DoubleType,true)
      .add("v75",DoubleType,true)
      .add("v76",DoubleType,true)
      .add("v77",DoubleType,true)
      .add("v78",DoubleType,true)
      .add("v79",DoubleType,true)
      .add("v8",DoubleType,true)
      .add("v80",DoubleType,true)
      .add("v81",DoubleType,true)
      .add("v82",DoubleType,true)
      .add("v83",DoubleType,true)
      .add("v84",DoubleType,true)
      .add("v85",DoubleType,true)
      .add("v86",DoubleType,true)
      .add("v87",DoubleType,true)
      .add("v88",DoubleType,true)
      .add("v89",DoubleType,true)
      .add("v9",DoubleType,true)
      .add("v90",DoubleType,true)
      .add("v92",DoubleType,true)
      .add("v93",DoubleType,true)
      .add("v94",DoubleType,true)
      .add("v95",DoubleType,true)
      .add("v96",DoubleType,true)
      .add("v97",DoubleType,true)
      .add("v98",DoubleType,true)
      .add("v99",DoubleType,true)
      .add("target",DoubleType,false)
    // scalastyle:on

    val passengersData = spark.read
      .option("header", "false") // if you have the header in your csv file to make the schema match
      // .option("inferSchema", "true") // allows to infer data types by Spark
      .schema(schema)
      .csv(pathToData.get)

    // Automated feature engineering
    val (survived, features) = FeatureBuilder.fromDataFrame[RealNN](passengersData, response = "target")
    val featureVector = features.transmogrify()

    // Automated feature selection
    val checkedFeatures = survived.sanityCheck(featureVector, checkSample = 1.0, removeBadFeatures = true)

    // Automated model selection
    val prediction = BinaryClassificationModelSelector
      .withCrossValidation(modelTypesToUse = Seq(OpLogisticRegression, OpRandomForestClassifier))
      .setInput(survived, checkedFeatures).getOutput()
    val model = new OpWorkflow().setInputDataset(passengersData).setResultFeatures(prediction).train()

    println("Model summary:\n" + model.summaryPretty())
  }

}
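
For reference, the long list of .add calls could also be generated programmatically. A minimal sketch, assuming the feature columns keep the vN naming and Double type used above (the index range, any missing indices such as v22 and v91, and the column order would need to match the actual csv file):

// build the same schema from a list of column names instead of spelling out every field
val featureNames = (1 to 131).map(i => s"v$i") // adjust for missing indices and the csv column order
val schema = StructType(
  featureNames.map(name => StructField(name, DoubleType, nullable = true)) :+
    StructField("target", DoubleType, nullable = false) // the response column is not nullable
)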