salesforce / TransmogrifAI

TransmogrifAI (pronounced trăns-mŏgˈrə-fī) is an AutoML library for building modular, reusable, strongly typed machine learning workflows on Apache Spark with minimal hand-tuning
https://transmogrif.ai
BSD 3-Clause "New" or "Revised" License

How to apply the trained model to predict new data? #347

Closed hzrt123 closed 5 years ago

hzrt123 commented 5 years ago

The workflow has train(), score(), evaluate() methods.

Now I just want to apply the trained model to make predictions on new data.

I noticed the relevant code in the library source:

    // Pre-compute transformations dag
    val dag = FitStagesUtil.computeDAG(resultFeatures)

    (path: Option[String]) => {
      // Generate the dataframe with raw features
      val rawData: DataFrame = generateRawData()

      // Apply the transformations DAG on raw data
      val transformedData: DataFrame = applyTransformationsDAG(rawData, dag, persistEveryKStages)

      // Save the scores
      val (scores, metrics) = saveScores(
        path = path,
        keepRawFeatures = keepRawFeatures,
        keepIntermediateFeatures = keepIntermediateFeatures,
        transformedData = transformedData,
        persistScores = persistScores,
        evaluator = evaluator,
        metricsPath = metricsPath
      )

Is there a direct way to do that?

tovbinm commented 5 years ago

You can compute the scores using the score() function; just set the new reader, dataset or RDD on the model:

val scores = model.setReader(scoringReader).score()

Or

val scores = model.setInputDataset(rawData).score()

Read more on how to score models here - https://docs.transmogrif.ai/en/stable/developer-guide/index.html#fitted-workflows

Another option is to use the TransmogrifAI local library, which computes scores without Spark and works on a simple Map instead. Read more here - https://github.com/salesforce/TransmogrifAI/blob/master/local/README.md

hzrt123 commented 5 years ago

For the Titanic classification example, is the score result supposed to look like the following?

root
 |-- key: string (nullable = false)
 |-- _c1-_c10-_c11-_c2-_c3-_c4-_c5-_c6-_c7-_c8-_c9_6-stagesApplied_Prediction_00000000001c: map (nullable = true)
 |    |-- key: string
 |    |-- value: double (valueContainsNull = true)

+--------------------+-------------------------------------------------------------------------------------+
|                 key|_c1-_c10-_c11-_c2-_c3-_c4-_c5-_c6-_c7-_c8-_c9_6-stagesApplied_Prediction_00000000001c|
+--------------------+-------------------------------------------------------------------------------------+
|-2504644008430921712|                                                                 [probability_1 ->...|
|-5852748975174411978|                                                                 [probability_1 ->...|
|-1576764987077729938|                                                                 [probability_1 ->...|
|-2998304527981748582|                                                                 [probability_1 ->...|
| 1298890169624068197|                                                                 [probability_1 ->...|
| 1206230684555713516|                                                                 [probability_1 ->...|
| 7012076935222637940|                                                                 [probability_1 ->...|
|-9220081699879933309|                                                                 [probability_1 ->...|
| 7961705637633786122|                                                                 [probability_1 ->...|
| 2503565079179236495|                                                                 [probability_1 ->...|
| 6592721250924103789|                                                                 [probability_1 ->...|
| 2636864085038063269|                                                                 [probability_1 ->...|
| 2476794700498658525|                                                                 [probability_1 ->...|
|-1258763335939758726|                                                                 [probability_1 ->...|
|-8153231848043822856|                                                                 [probability_1 ->...|
| 2302308033454822382|                                                                 [probability_1 ->...|
|-5183210924796450720|                                                                 [probability_1 ->...|
| 9218625762318447948|                                                                 [probability_1 ->...|
|-7716886433147834982|                                                                 [probability_1 ->...|
| 4096046604846163920|                                                                 [probability_1 ->...|
| 4352352368593652281|                                                                 [probability_1 ->...|
|-4470563387051445399|                                                                 [probability_1 ->...|
|-2710768366221124293|                                                                 [probability_1 ->...|
|-4999392867551541057|                                                                 [probability_1 ->...|
| 6471245122003987378|                                                                 [probability_1 ->...|
| 9081420369086491766|                                                                 [probability_1 ->...|
| 4147281237869004710|                                                                 [probability_1 ->...|
| 5155598579162960564|                                                                 [probability_1 ->...|
| 1792074383230337337|                                                                 [probability_1 ->...|
|-7322189762777026848|                                                                 [probability_1 ->...|
|-1898744015554001396|                                                                 [probability_1 ->...|
| 1601908928397917101|                                                                 [probability_1 ->...|
|-8404405256211672351|                                                                 [probability_1 ->...|
|-7362583582356668586|                                                                 [probability_1 ->...|
|-3098164233332925566|                                                                 [probability_1 ->...|
| 5040584896173756458|                                                                 [probability_1 ->...|
| 7077349930756082912|                                                                 [probability_1 ->...|
|-6029724398088970933|                                                                 [probability_1 ->...|
|  512414287833877886|                                                                 [probability_1 ->...|
|-2948024629490843726|                                                                 [probability_1 ->...|
| 5912394566122479631|                                                                 [probability_1 ->...|
| 6138738268747765621|                                                                 [probability_1 ->...|
|-2681920847670866850|                                                                 [probability_1 ->...|
|-7067694234419529396|                                                                 [probability_1 ->...|
| 3835806169849541229|                                                                 [probability_1 ->...|
|-3976106547084038063|                                                                 [probability_1 ->...|
| 2881209293536521185|                                                                 [probability_1 ->...|
| 3994272608213997574|                                                                 [probability_1 ->...|
|  978909686352879516|                                                                 [probability_1 ->...|
|-6258075958269111652|                                                                 [probability_1 ->...|
| -564849509991371861|                                                                 [probability_1 ->...|
|-6844044759569014926|                                                                 [probability_1 ->...|
| 7146155340220303715|                                                                 [probability_1 ->...|
|-5372369076075293221|                                                                 [probability_1 ->...|
|-8717240326829797434|                                                                 [probability_1 ->...|
|-7703982118043809756|                                                                 [probability_1 ->...|
|   59803964026517546|                                                                 [probability_1 ->...|
| 6268292956568294100|                                                                 [probability_1 ->...|
|-5968893085032511204|                                                                 [probability_1 ->...|
| 4034914820497841205|                                                                 [probability_1 ->...|
|-5019169961801905279|                                                                 [probability_1 ->...|
|-5634624718223374317|                                                                 [probability_1 ->...|
| -793647164141572610|                                                                 [probability_1 ->...|
|-1013978272036150333|                                                                 [probability_1 ->...|
|-4938493913525405031|                                                                 [probability_1 ->...|
|  394577560229564006|                                                                 [probability_1 ->...|
| 7043782964526564000|                                                                 [probability_1 ->...|
|-8395479954969206417|                                                                 [probability_1 ->...|
|-2728772298414119735|                                                                 [probability_1 ->...|
|-5210169724871926430|                                                                 [probability_1 ->...|
| 6903458894901186970|                                                                 [probability_1 ->...|
|-4309664287309979939|                                                                 [probability_1 ->...|
|-8530503925364067830|                                                                 [probability_1 ->...|
| 6279663884544155611|                                                                 [probability_1 ->...|
| 7798214158232179816|                                                                 [probability_1 ->...|
|-6690640581782369993|                                                                 [probability_1 ->...|
| 6302773031326335868|                                                                 [probability_1 ->...|
| 4812590055781537903|                                                                 [probability_1 ->...|
|-4505935422339781653|                                                                 [probability_1 ->...|
|-5505905994480902741|                                                                 [probability_1 ->...|
| 5652144454817264607|                                                                 [probability_1 ->...|
| 7398930427777160701|                                                                 [probability_1 ->...|
|-4769045408768907033|                                                                 [probability_1 ->...|
| 4038225727101347304|                                                                 [probability_1 ->...|
| 3947991444243212315|                                                                 [probability_1 ->...|
|-5986580932422158298|                                                                 [probability_1 ->...|
|-4064474180874050429|                                                                 [probability_1 ->...|
| 2325907313491526821|                                                                 [probability_1 ->...|
| 7973862997079830456|                                                                 [probability_1 ->...|
| 6220751715784524755|                                                                 [probability_1 ->...|
| 3064237911437355212|                                                                 [probability_1 ->...|
| 3992056253513634331|                                                                 [probability_1 ->...|
|   29103031830263143|                                                                 [probability_1 ->...|
| 7778101407104649134|                                                                 [probability_1 ->...|
|-3983431972538685135|                                                                 [probability_1 ->...|
|-8024823779650320250|                                                                 [probability_1 ->...|
|-5760547258037190793|                                                                 [probability_1 ->...|
|-2345089031606431577|                                                                 [probability_1 ->...|
|-9144453963305268472|                                                                 [probability_1 ->...|
|-5570205309899276121|                                                                 [probability_1 ->...|
+--------------------+-------------------------------------------------------------------------------------+
only showing top 100 rows

hzrt123 commented 5 years ago

The BostonHouse regression example result is:

root
 |-- key: string (nullable = false)
 |-- _c1-_c10-_c11-_c12-_c13-_c14-_c2-_c3-_c4-_c5-_c6-_c7-_c8-_c9_4-stagesApplied_Prediction_00000000001b: map (nullable = true)
 |    |-- key: string
 |    |-- value: double (valueContainsNull = true)

+--------------------+----------------------------------------------------------------------------------------------------+
|                 key|_c1-_c10-_c11-_c12-_c13-_c14-_c2-_c3-_c4-_c5-_c6-_c7-_c8-_c9_4-stagesApplied_Prediction_00000000001b|
+--------------------+----------------------------------------------------------------------------------------------------+
|  633710364778013375|                                                                                [prediction -> 27...|
| 1200380298067033798|                                                                                [prediction -> 22...|
| 8047837604097915986|                                                                                [prediction -> 28...|
|-4118144935784069269|                                                                                [prediction -> 29...|
|-5970167395710358811|                                                                                [prediction -> 28...|
|-3093987966044581729|                                                                                [prediction -> 26...|
|  -47273471128075512|                                                                                [prediction -> 21...|
| 7101064316682932118|                                                                                [prediction -> 20...|
|  560975998538274330|                                                                                [prediction -> 17...|
|  218534349360281360|                                                                                [prediction -> 18...|
|-6952414782331059368|                                                                                [prediction -> 18...|
|-5358487499809973758|                                                                                [prediction -> 19...|
|-8409156428111309547|                                                                                [prediction -> 20...|
|-7104722277960854412|                                                                                [prediction -> 19...|
| 7025007749301169738|                                                                                [prediction -> 17...|
| 9160061661233454764|                                                                                [prediction -> 19...|
| 5907716199508593745|                                                                                [prediction -> 20...|
|-6224242743367192948|                                                                                [prediction -> 17...|
| 1503766638988730605|                                                                                [prediction -> 18...|
| 5260121157813488440|                                                                                [prediction -> 18...|
| 6016618760443332749|                                                                                [prediction -> 15...|
| 2988365267940538107|                                                                                [prediction -> 16...|
|-4403961218693472047|                                                                                [prediction -> 16...|
|-5263942831760387221|                                                                                [prediction -> 16...|
|-8448915853358218419|                                                                                [prediction -> 16...|
| 2470302282764886986|                                                                                [prediction -> 16...|
|-4922152924916520820|                                                                                [prediction -> 16...|
| 5171410514645651887|                                                                                [prediction -> 15...|
|-7652865636338780329|                                                                                [prediction -> 18...|
| 8281947226231871461|                                                                                [prediction -> 20...|
| 6987688017127408694|                                                                                [prediction -> 15...|
|-3064939763291998318|                                                                                [prediction -> 15...|
| 1634111235922019638|                                                                                [prediction -> 16...|
| 6427970600582798235|                                                                                [prediction -> 15...|
| 3057881681555373548|                                                                                [prediction -> 15...|
| -924642877265441209|                                                                                [prediction -> 20...|
| 4714855037520766906|                                                                                [prediction -> 20...|
| 5636595814817260834|                                                                                [prediction -> 21...|
|-7343617481906087344|                                                                                [prediction -> 20...|
| 4846959360422233891|                                                                                [prediction -> 28...|
|-8015296657024239499|                                                                                [prediction -> 28...|
| 5842963111518007568|                                                                                [prediction -> 27...|
| -622176671756458745|                                                                                [prediction -> 24...|
|  248615583395961889|                                                                                [prediction -> 24...|
|-8061884003959962560|                                                                                [prediction -> 22...|
|-3607460880363984157|                                                                                [prediction -> 21...|
|-5656485718008296134|                                                                                [prediction -> 20...|
| 6310384802265858143|                                                                                [prediction -> 18...|
|  510405844316254132|                                                                                [prediction -> 18...|
| 1081065311708996987|                                                                                [prediction -> 19...|
| 4227966093107859990|                                                                                [prediction -> 21...|
| 5726676337669578911|                                                                                [prediction -> 22...|
|-7536679240885589070|                                                                                [prediction -> 25...|
|-8326699477339331686|                                                                                [prediction -> 22...|
|  494720351278038860|                                                                                [prediction -> 20...|
| 4553315656430449607|                                                                                [prediction -> 29...|
|  869643542178937308|                                                                                [prediction -> 24...|
| 9150524935690563691|                                                                                [prediction -> 28...|
| 1906139050342611345|                                                                                [prediction -> 23...|
|-4306002623604319391|                                                                                [prediction -> 21...|
| 3252155366847922457|                                                                                [prediction -> 20...|
|-3408900172668860424|                                                                                [prediction -> 17...|
| 4464070171908004639|                                                                                [prediction -> 24...|
|-4765338626042473444|                                                                                [prediction -> 24...|
|-1629807534293224939|                                                                                [prediction -> 28...|
|-1850004427180683341|                                                                                [prediction -> 24...|
| 9040713933854793778|                                                                                [prediction -> 21...|
|-7969957096838089143|                                                                                [prediction -> 22...|
|  860070706899975102|                                                                                [prediction -> 20...|
|-4380187397077229842|                                                                                [prediction -> 21...|
| 1420884256361997197|                                                                                [prediction -> 23...|
| 5704197988173678174|                                                                                [prediction -> 20...|
|-9126621269980966968|                                                                                [prediction -> 23...|
|-5150513394026162334|                                                                                [prediction -> 21...|
| 4738877697603840755|                                                                                [prediction -> 23...|
|-5790080731319397002|                                                                                [prediction -> 22...|
| 7268160249279706922|                                                                                [prediction -> 20...|
| 2906458397611462060|                                                                                [prediction -> 21...|
+--------------------+----------------------------------------------------------------------------------------------------+

hzrt123 commented 5 years ago

I also tried the code:

// Load the trained model
val model = OpWorkflowModel.load("/path/to/model")

// Create score function once and use it indefinitely
val scoreFn = model.scoreFunction

// Spark Session can be stopped now since it's not required during local scoring
spark.stop()

// Compute scores with score function
val rawData = Seq(Map("name" -> "Peter", "age" -> 18), Map("name" -> "John", "age" -> 23))
val scores = rawData.map(scoreFn)

But OpWorkflowModel does not seem to have a load() method, so val model = OpWorkflowModel.load("/path/to/model") fails. I don't know how to fix it.

I also don't understand what val rawData = Seq(Map("name" -> "Peter", "age" -> 18), Map("name" -> "John", "age" -> 23)) means without any dataset context.

Do you have any further context on that? I would appreciate it very much. Thanks.

tovbinm commented 5 years ago

The above model output looks correct. You have two columns representing a key and the prediction (the key is defined on your reader using the keyFn).
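As a rough illustration of how that prediction column is shaped: the schema above shows it as a map from string to double. The sketch below is plain Scala without Spark. Only the key names (prediction, probability_1) come from the printed output; the numeric values and the helper name entry are made up for illustration.

```scala
object PredictionMapSketch {
  // Only the key names ("prediction", "probability_1") come from the output above;
  // the numeric values are made up for illustration.
  val regressionRow: Map[String, Double] = Map("prediction" -> 27.0)
  val classificationRow: Map[String, Double] = Map("probability_1" -> 0.75)

  // Look up one entry of a prediction map, returning None if it is absent.
  def entry(row: Map[String, Double], name: String): Option[Double] = row.get(name)

  def main(args: Array[String]): Unit = {
    println(entry(regressionRow, "prediction"))
    println(entry(classificationRow, "probability_1"))
    println(entry(regressionRow, "probability_1"))
  }
}
```

Using Option here avoids an exception when a row does not carry the entry you ask for.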

  1. What's the error on OpWorkflowModel.load?

  2. For local scoring, each record (represented by Map(...)) needs to contain the keys and the values of your features so the model can compute your score. In the above example the raw data had two raw features, name and age.
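
One way to catch an incomplete record early is to compare its keys against the raw feature names before scoring. This is a minimal sketch in plain Scala and assumes nothing about the TransmogrifAI API: requiredFeatures and missingFeatures are hypothetical names, and the feature names name and age are taken from the example above.

```scala
// Hypothetical helper, not part of TransmogrifAI: reports which raw feature
// names are absent from a record before it is handed to a score function.
object RecordCheck {
  // Assumed feature names, taken from the README example quoted above.
  val requiredFeatures: Set[String] = Set("name", "age")

  // Names in requiredFeatures that the record has no entry for.
  def missingFeatures(record: Map[String, Any]): Set[String] =
    requiredFeatures -- record.keySet

  def main(args: Array[String]): Unit = {
    val rawData: Seq[Map[String, Any]] = Seq(
      Map("name" -> "Peter", "age" -> 18),
      Map("name" -> "John") // "age" is missing here
    )
    rawData.zipWithIndex.foreach { case (record, i) =>
      val missing = missingFeatures(record)
      if (missing.nonEmpty) println(s"record $i is missing: ${missing.mkString(", ")}")
    }
  }
}
```

For a real model you would replace the hard-coded set with the raw feature names your workflow was trained on.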

hzrt123 commented 5 years ago

The load() method could not be found on OpWorkflowModel.

hzrt123 commented 5 years ago

[Screenshot attached: 2019-07-03 11-22-56]

tovbinm commented 5 years ago

Oh right, this method hasn't been released yet and is only available from the master branch.

For the latest released version, 0.5.x, you would need to call the old method on the workflow, i.e.:

val model = workflow.load(path)

tovbinm commented 5 years ago

Please review the documentation - https://docs.transmogrif.ai/en/stable/developer-guide/index.html#loading-saved-workflows

hzrt123 commented 5 years ago

I tested the Titanic example:

    val model = new OpWorkflow().setInputDataset(passengersData).setResultFeatures(prediction).train()

    val scoreFn = model.scoreFunction

    val rawData = Seq(
      Map("_c2" -> 3, "_c3" -> "Braund, Mr. Owen", "_c4" -> "male", "_c5" -> 22.0, "_c6" -> 1,
        "_c7" -> 0, "_c8" -> "A/5 21171", "_c9" -> 7.25, "_c10" -> null, "_c11" -> "S"),
      Map("_c2" -> 1, "_c3" -> "Cumings, Mrs. Joh", "_c4" -> "female", "_c5" -> 38.0, "_c6" -> 1,
        "_c7" -> 0, "_c8" -> "PC 17599", "_c9" -> 71.2833, "_c10" -> "C85", "_c11" -> "C")
    )

    val scores_predict = rawData.map(scoreFn)

    print(scores_predict)

The DataFrame looks like this:

+---+---+--------------------+------+----+---+---+----------------+--------+-----------+----+
|_c1|_c2|                 _c3|   _c4| _c5|_c6|_c7|             _c8|     _c9|       _c10|_c11|
+---+---+--------------------+------+----+---+---+----------------+--------+-----------+----+
|0.0|  3|Braund, Mr. Owen ...|  male|22.0|  1|  0|       A/5 21171|    7.25|       null|   S|
|1.0|  1|Cumings, Mrs. Joh...|female|38.0|  1|  0|        PC 17599| 71.2833|        C85|   C|
|1.0|  3|Heikkinen, Miss. ...|female|26.0|  0|  0|STON/O2. 3101282|   7.925|       null|   S|
|1.0|  1|Futrelle, Mrs. Ja...|female|35.0|  1|  0|          113803|    53.1|       C123|   S|
|0.0|  3|Allen, Mr. Willia...|  male|35.0|  0|  0|          373450|    8.05|       null|   S|
|0.0|  3|    Moran, Mr. James|  male|null|  0|  0|          330877|  8.4583|       null|   Q|
|0.0|  1|McCarthy, Mr. Tim...|  male|54.0|  0|  0|           17463| 51.8625|        E46|   S|
|0.0|  3|Palsson, Master. ...|  male| 2.0|  3|  1|          349909|  21.075|       null|   S|
|1.0|  3|Johnson, Mrs. Osc...|female|27.0|  0|  2|          347742| 11.1333|       null|   S|
|1.0|  2|Nasser, Mrs. Nich...|female|14.0|  1|  0|          237736| 30.0708|       null|   C|
|1.0|  3|Sandstrom, Miss. ...|female| 4.0|  1|  1|         PP 9549|    16.7|         G6|   S|
|1.0|  1|Bonnell, Miss. El...|female|58.0|  0|  0|          113783|   26.55|       C103|   S|
|0.0|  3|Saundercock, Mr. ...|  male|20.0|  0|  0|       A/5. 2151|    8.05|       null|   S|
|0.0|  3|Andersson, Mr. An...|  male|39.0|  1|  5|          347082|  31.275|       null|   S|

The output is an error:

Exception in thread "main" java.util.NoSuchElementException: key not found: _c1
        at scala.collection.MapLike$class.default(MapLike.scala:228)
        at scala.collection.AbstractMap.default(Map.scala:59)
        at scala.collection.mutable.HashMap.apply(HashMap.scala:65)
        at com.salesforce.op.local.OpWorkflowModelLocal$RichOpWorkflowModel$$anonfun$scoreFunction$1$$anonfun$7$$anonfun$apply$1.apply(OpWorkflowModelLocal.scala:122)
        at com.salesforce.op.local.OpWorkflowModelLocal$RichOpWorkflowModel$$anonfun$scoreFunction$1$$anonfun$7$$anonfun$apply$1.apply(OpWorkflowModelLocal.scala:122)
        at com.salesforce.op.stages.base.binary.OpTransformer2$$anonfun$transformKeyValue$1.apply(BinaryTransformer.scala:88)
        at com.salesforce.op.stages.base.binary.OpTransformer2$$anonfun$transformKeyValue$1.apply(BinaryTransformer.scala:88)
        at com.salesforce.op.local.OpWorkflowModelLocal$RichOpWorkflowModel$$anonfun$scoreFunction$1$$anonfun$7.apply(OpWorkflowModelLocal.scala:122)
        at com.salesforce.op.local.OpWorkflowModelLocal$RichOpWorkflowModel$$anonfun$scoreFunction$1$$anonfun$7.apply(OpWorkflowModelLocal.scala:119)
        at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57)
        at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66)
        at scala.collection.mutable.ArrayOps$ofRef.foldLeft(ArrayOps.scala:186)
        at com.salesforce.op.local.OpWorkflowModelLocal$RichOpWorkflowModel$$anonfun$scoreFunction$1.apply(OpWorkflowModelLocal.scala:119)
        at com.salesforce.op.local.OpWorkflowModelLocal$RichOpWorkflowModel$$anonfun$scoreFunction$1.apply(OpWorkflowModelLocal.scala:117)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.immutable.List.foreach(List.scala:381)
        at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
        at scala.collection.immutable.List.map(List.scala:285)
        at com.salesforce.hw.OpTitanicMini$.main(test_data.scala:210)
        at com.salesforce.hw.OpTitanicMini.main(test_data.scala)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:498)
        at org.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52)
        at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:894)
        at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:198)
        at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:228)
        at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:137)
        at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)

tovbinm commented 5 years ago

Just add a key field to each of your raw data items, e.g.:

val rawDataWithKey = rawData.map(m => m + ("key" -> util.Random.nextLong))
val scores_predict = rawDataWithKey.map(scoreFn)
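
To try the same idea without a trained model, here is a self-contained sketch in plain Scala. fakeScoreFn is a hypothetical stand-in for model.scoreFunction (it only echoes the key and returns a placeholder score), and withKey is an illustrative helper name, not a library call.

```scala
import scala.util.Random

object KeyedScoring {
  type Record = Map[String, Any]

  // Hypothetical stand-in for model.scoreFunction so the sketch runs without Spark
  // or a trained model. record("key") throws NoSuchElementException when the
  // "key" entry is absent; 0.0 is a placeholder, not a real prediction.
  val fakeScoreFn: Record => Map[String, Any] = record =>
    Map("key" -> record("key"), "score" -> 0.0)

  // Adds a random "key" entry to every record, as suggested above.
  def withKey(records: Seq[Record]): Seq[Record] =
    records.map(r => r + ("key" -> Random.nextLong()))

  def main(args: Array[String]): Unit = {
    val rawData: Seq[Record] = Seq(
      Map("_c4" -> "male", "_c5" -> 22.0),
      Map("_c4" -> "female", "_c5" -> 38.0)
    )
    val scores = withKey(rawData).map(fakeScoreFn)
    scores.foreach(println)
  }
}
```

Random keys are fine for a quick test; if you need to join scores back to your source rows, a stable identifier from your own data is probably a better choice.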

hzrt123 commented 5 years ago

OK, I got it. Thanks a lot.

DineshRajanT commented 4 years ago

> Oh right, this method wasn’t released yet and only available from the master branch.
>
> For the latest released version 0.5.x you would need to call old method on the workflow, I.e:
>
> val model = workflow.load(path)

I have tried the above method too, but it still throws errors.