asked Jason to take a look
ack. I'll take a look shortly.
Hi @brandongreenwell-8451, I had some trouble running your code in Synapse, and I don't have access to Databricks. But reading your code, I think the problem is here:
onnx_ml = (
onnx_ml.setDeviceType("CPU")
.setFeedDict(onnx_ml.getModelInputs()) # not sure if this is the correct way?!
.setFetchDict(onnx_ml.getModelOutputs()) # not sure if this is the correct way?
)
The setFeedDict method sets up the mapping between the ONNX input nodes and the columns of the input Spark dataframe (X_test_sdf), and the setFetchDict method sets up the mapping between the ONNX output nodes and the columns of the output Spark dataframe.
The setFeedDict method expects a dict whose keys are the ONNX input node names and whose values are the corresponding column names in the input Spark dataframe. The setFetchDict method expects a dict whose keys are the column names to produce in the output Spark dataframe and whose values are the corresponding ONNX output node names.
For example, if the ONNX model's input node is named "input" and its output nodes are named "probabilities" and "label", then we can set up the model with the following:
onnx_ml = (
onnx_ml.setDeviceType("CPU")
.setFeedDict({"input": "features"})
.setFetchDict({"probability": "probabilities", "prediction": "label"})
.setMiniBatchSize(5000)
)
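With that configuration, scoring is just a transform call. A minimal sketch (assuming an input dataframe, here called input_sdf, that actually contains the "features" column referenced in the feed dict above):
# Hedged sketch: apply the configured ONNXModel and inspect the mapped output columns.
scored_sdf = onnx_ml.transform(input_sdf)
scored_sdf.select("probability", "prediction").show(5, truncate=False)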
If you still have trouble, please show the output from these two lines:
print("Model inputs:" + str(onnx_ml.getModelInputs()))
print("Model outputs:" + str(onnx_ml.getModelOutputs()))
and also show the schema of X_test_sdf by running:
X_test_sdf.printSchema()
Hi @memoryz, thank you for taking a look. I am still having trouble, however, so here's the output you requested:
print("Model inputs:" + str(onnx_ml.getModelInputs()))
Model inputs:{'HoursPerWeek': NodeInfo(name=HoursPerWeek,info=TensorInfo(shape=[-1], type=INT64)), 'CapitalGain': NodeInfo(name=CapitalGain,info=TensorInfo(shape=[-1], type=INT64)), 'Education': NodeInfo(name=Education,info=TensorInfo(shape=[-1], type=STRING)), 'WorkClass': NodeInfo(name=WorkClass,info=TensorInfo(shape=[-1], type=STRING)), 'Race': NodeInfo(name=Race,info=TensorInfo(shape=[-1], type=STRING)), 'Age': NodeInfo(name=Age,info=TensorInfo(shape=[-1], type=INT64)), 'Relationship': NodeInfo(name=Relationship,info=TensorInfo(shape=[-1], type=STRING)), 'EducationNum': NodeInfo(name=EducationNum,info=TensorInfo(shape=[-1], type=INT64)), 'NativeCountry': NodeInfo(name=NativeCountry,info=TensorInfo(shape=[-1], type=STRING)), 'CapitalLoss': NodeInfo(name=CapitalLoss,info=TensorInfo(shape=[-1], type=INT64)), 'Gender': NodeInfo(name=Gender,info=TensorInfo(shape=[-1], type=STRING)), 'MaritalStatus': NodeInfo(name=MaritalStatus,info=TensorInfo(shape=[-1], type=STRING)), 'Occupation': NodeInfo(name=Occupation,info=TensorInfo(shape=[-1], type=STRING)), 'fnlwgt': NodeInfo(name=fnlwgt,info=TensorInfo(shape=[-1], type=INT64))}
print("Model outputs:" + str(onnx_ml.getModelOutputs()))
Model outputs:{'predict_proba_0': NodeInfo(name=predict_proba_0,info=TensorInfo(shape=[-1,2], type=FLOAT))}
print(X_test_sdf.printSchema())
root
|-- Age: long (nullable = true)
|-- WorkClass: string (nullable = true)
|-- fnlwgt: long (nullable = true)
|-- Education: string (nullable = true)
|-- EducationNum: long (nullable = true)
|-- MaritalStatus: string (nullable = true)
|-- Occupation: string (nullable = true)
|-- Relationship: string (nullable = true)
|-- Race: string (nullable = true)
|-- Gender: string (nullable = true)
|-- CapitalGain: long (nullable = true)
|-- CapitalLoss: long (nullable = true)
|-- HoursPerWeek: long (nullable = true)
|-- NativeCountry: string (nullable = true)
None
Based on your comment, I did change these parameters to
onnx_ml = (
onnx_ml.setDeviceType("CPU")
.setFeedDict(dict(zip(X_test_sdf.columns, X_test_sdf.columns))) # keys/values match column names
.setFetchDict({'predict_proba_0': 'predict_proba_0'})
.setMiniBatchSize(5000)
)
And the .transform() call still failed, but not right away this time. Here's the new error:
IllegalArgumentException: **Image batch is not a sequence**
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 2.0 failed 4 times, most recent failure: Lost task 0.3 in stage 2.0 (TID 11) (10.2.64.8 executor 0): java.lang.IllegalArgumentException: Image batch is not a sequence
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1(ONNXUtils.scala:130)
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1$adapted(ONNXUtils.scala:128)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.validateBatchShapes(ONNXUtils.scala:128)
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.createTensor(ONNXUtils.scala:104)
at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$2(ONNXRuntime.scala:74)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
at scala.collection.IterableLike.foreach(IterableLike.scala:74)
at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$1(ONNXRuntime.scala:67)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
at com.microsoft.azure.synapse.ml.core.utils.CloseableIterator.next(CloseableIterator.scala:11)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:155)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:156)
at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:125)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.Task.run(Task.scala:95)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:832)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1681)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:835)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:690)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:750)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3029)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2976)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2970)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2970)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1390)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1390)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1390)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3238)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3179)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3167)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1152)
at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2657)
at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:266)
at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:276)
at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:81)
at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:87)
at org.apache.spark.sql.execution.collect.InternalRowFormat$.collect(cachedSparkResults.scala:75)
at org.apache.spark.sql.execution.collect.InternalRowFormat$.collect(cachedSparkResults.scala:62)
at org.apache.spark.sql.execution.ResultCacheManager.collectResult$1(ResultCacheManager.scala:573)
at org.apache.spark.sql.execution.ResultCacheManager.computeResult(ResultCacheManager.scala:582)
at org.apache.spark.sql.execution.ResultCacheManager.$anonfun$getOrComputeResultInternal$1(ResultCacheManager.scala:528)
at scala.Option.getOrElse(Option.scala:189)
at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResultInternal(ResultCacheManager.scala:527)
at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:424)
at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:403)
at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:424)
at org.apache.spark.sql.Dataset.collectResult(Dataset.scala:3153)
at org.apache.spark.sql.Dataset.$anonfun$collectResult$1(Dataset.scala:3144)
at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3951)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withCustomExecutionEnv$8(SQLExecution.scala:239)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:386)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withCustomExecutionEnv$1(SQLExecution.scala:186)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:968)
at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:141)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:336)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3949)
at org.apache.spark.sql.Dataset.collectResult(Dataset.scala:3143)
at com.databricks.backend.daemon.driver.OutputAggregator$.withOutputAggregation0(OutputAggregator.scala:266)
at com.databricks.backend.daemon.driver.OutputAggregator$.withOutputAggregation(OutputAggregator.scala:100)
at com.databricks.backend.daemon.driver.PythonDriverLocalBase.generateTableResult(PythonDriverLocalBase.scala:723)
at com.databricks.backend.daemon.driver.PythonDriverLocal.computeListResultsItem(PythonDriverLocal.scala:627)
at com.databricks.backend.daemon.driver.PythonDriverLocalBase.genListResults(PythonDriverLocalBase.scala:630)
at com.databricks.backend.daemon.driver.PythonDriverLocal.$anonfun$getResultBufferInternal$1(PythonDriverLocal.scala:682)
at com.databricks.backend.daemon.driver.PythonDriverLocal.withInterpLock(PythonDriverLocal.scala:563)
at com.databricks.backend.daemon.driver.PythonDriverLocal.getResultBufferInternal(PythonDriverLocal.scala:642)
at com.databricks.backend.daemon.driver.DriverLocal.getResultBuffer(DriverLocal.scala:748)
at com.databricks.backend.daemon.driver.PythonDriverLocal.outputSuccess(PythonDriverLocal.scala:605)
at com.databricks.backend.daemon.driver.PythonDriverLocal.$anonfun$repl$6(PythonDriverLocal.scala:223)
at com.databricks.backend.daemon.driver.PythonDriverLocal.withInterpLock(PythonDriverLocal.scala:563)
at com.databricks.backend.daemon.driver.PythonDriverLocal.repl(PythonDriverLocal.scala:210)
at com.databricks.backend.daemon.driver.DriverLocal.$anonfun$execute$13(DriverLocal.scala:634)
at com.databricks.logging.Log4jUsageLoggingShim$.$anonfun$withAttributionContext$1(Log4jUsageLoggingShim.scala:33)
at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
at com.databricks.logging.AttributionContext$.withValue(AttributionContext.scala:94)
at com.databricks.logging.Log4jUsageLoggingShim$.withAttributionContext(Log4jUsageLoggingShim.scala:31)
at com.databricks.logging.UsageLogging.withAttributionContext(UsageLogging.scala:205)
at com.databricks.logging.UsageLogging.withAttributionContext$(UsageLogging.scala:204)
at com.databricks.backend.daemon.driver.DriverLocal.withAttributionContext(DriverLocal.scala:59)
at com.databricks.logging.UsageLogging.withAttributionTags(UsageLogging.scala:240)
at com.databricks.logging.UsageLogging.withAttributionTags$(UsageLogging.scala:225)
at com.databricks.backend.daemon.driver.DriverLocal.withAttributionTags(DriverLocal.scala:59)
at com.databricks.backend.daemon.driver.DriverLocal.execute(DriverLocal.scala:611)
at com.databricks.backend.daemon.driver.DriverWrapper.$anonfun$tryExecutingCommand$1(DriverWrapper.scala:615)
at scala.util.Try$.apply(Try.scala:213)
at com.databricks.backend.daemon.driver.DriverWrapper.tryExecutingCommand(DriverWrapper.scala:607)
at com.databricks.backend.daemon.driver.DriverWrapper.executeCommandAndGetError(DriverWrapper.scala:526)
at com.databricks.backend.daemon.driver.DriverWrapper.executeCommand(DriverWrapper.scala:561)
at com.databricks.backend.daemon.driver.DriverWrapper.runInnerLoop(DriverWrapper.scala:431)
at com.databricks.backend.daemon.driver.DriverWrapper.runInner(DriverWrapper.scala:374)
at com.databricks.backend.daemon.driver.DriverWrapper.run(DriverWrapper.scala:225)
at java.lang.Thread.run(Thread.java:750)
Caused by: java.lang.IllegalArgumentException: Image batch is not a sequence
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1(ONNXUtils.scala:130)
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1$adapted(ONNXUtils.scala:128)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.validateBatchShapes(ONNXUtils.scala:128)
at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.createTensor(ONNXUtils.scala:104)
at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$2(ONNXRuntime.scala:74)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
at scala.collection.IterableLike.foreach(IterableLike.scala:74)
at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$1(ONNXRuntime.scala:67)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
at com.microsoft.azure.synapse.ml.core.utils.CloseableIterator.next(CloseableIterator.scala:11)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:155)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:156)
at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:125)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.Task.run(Task.scala:95)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:832)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1681)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:835)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:690)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
Hello @brandongreenwell-8451, the output shape of the EBM with probabilities is [-1, 2]. I do not know how SynapseML maps a multi-dimensional output to a column; maybe this is the cause of the issue. Can you serialize the model without predict_proba to test if it helps? The output will then only be the predicted class instead of the probability for each class (and the name of the output will be different).
Hi @MainRo, same error when serializing without the predict_proba option and updating the .setFetchDict argument accordingly.
print("Model outputs:" + str(onnx_ml.getModelOutputs()))
Model outputs:{'predict_0': NodeInfo(name=predict_0,info=TensorInfo(shape=[-1], type=INT64))}
Since this is a binary outcome, the goal is to be able to predict probabilities, as opposed to class labels.
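(For reference, the conversion step being discussed looks roughly like this; a hedged sketch assuming the ebm2onnx API, where exact argument names may differ, and where ebm and X_train are placeholders for the fitted EBM and the training frame:)
# Hedged sketch of exporting the fitted EBM with probability output enabled.
import ebm2onnx

onnx_model = ebm2onnx.to_onnx(
    ebm,
    ebm2onnx.get_dtype_from_pandas(X_train),
    predict_proba=True,  # export probabilities instead of class labels
)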
@brandongreenwell-8451 I can finally reproduce the issue and have a PR #1906 sent out.
@memoryz this is great to hear! Will this also fix the case for predict_proba=True, which gives shape [-1, 2]?
The case for predict_proba=True is not impacted by this bug or the bug fix. When the output shape is [-1, 2], the output column will contain arrays of length 2. To illustrate the usage:
// Configure the ONNXModel: feed each input column to the ONNX input of the
// same name, expose the "probabilities" output as a "probability" column,
// and derive a "prediction" column by taking the argmax over the probabilities.
val onnxGH1902 = new ONNXModel()
  .setModelLocation(model.getPath)
  .setDeviceType("CPU")
  .setFeedDict(featuresAdultsIncome.map(v => (v, v)).toMap)
  .setFetchDict(Map("probability" -> "probabilities"))
  .setArgMaxDict(Map("probability" -> "prediction"))
  .setMiniBatchSize(5000)

// Score the test dataframe and check the per-class probabilities and the
// argmax-derived prediction for each row.
val Array(row1, row2) = onnxGH1902.transform(testDfGH1902)
  .select("probability", "prediction")
  .orderBy(col("prediction"))
  .as[(Seq[Double], Double)]
  .collect()

assert(row1._1 === Seq(0.9343283176422119, 0.0656716451048851))
assert(row1._2 === 0.0)
assert(row2._1 === Seq(0.16954122483730316, 0.8304587006568909))
assert(row2._2 === 1.0)
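The same setup in PySpark would look roughly like this (a hedged sketch, assuming the Python ONNXModel wrapper exposes the same setters as the Scala class above, and using the predict_proba_0 output name from your getModelOutputs() output):
# Hedged PySpark sketch mirroring the Scala example above.
from synapse.ml.onnx import ONNXModel

onnx_ml = (
    ONNXModel()
    .setModelPayload(model_payload)  # placeholder: serialized ONNX model bytes
    .setDeviceType("CPU")
    .setFeedDict({c: c for c in X_test_sdf.columns})
    .setFetchDict({"probability": "predict_proba_0"})
    .setArgMaxDict({"probability": "prediction"})
    .setMiniBatchSize(5000)
)

scored = onnx_ml.transform(X_test_sdf).select("probability", "prediction")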
@brandongreenwell-8451 to test the fix, please use this coordinate: com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT with the resolver: https://mmlspark.azureedge.net/maven. This is from the PR build.
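If you create the session yourself, one way to pull that snapshot is via the standard Spark package settings (a hedged sketch; on Databricks the equivalent is installing the coordinate as a cluster Maven library with the extra repository):
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.jars.packages",
            "com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT")
    .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
    .getOrCreate()
)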
Huzzah!! @memoryz, I am indeed getting probabilities now! This is great, and our team truly appreciates the effort!
@memoryz , is it possible to use a package with this fix in Python? I think I hit the same issue where inputs of my model are:
Model inputs: {
'attention_mask': NodeInfo(name=attention_mask, info=TensorInfo(shape=[-1,-1], type=INT64)),
'input_ids': NodeInfo(name=input_ids, info=TensorInfo(shape=[-1,-1], type=INT64)),
'token_type_ids': NodeInfo(name=token_type_ids, info=TensorInfo(shape=[-1,-1], type=INT64))
}
Alternatively, is it possible to modify the shape via synapse.ml.onnx syntax, as I think my shape is in reality fixed at 512?
@fralik my original issue was in Python (see the original reproducible example), but this seemingly fixed it for me?
> @memoryz, is it possible to use a package with this fix in Python? I think I hit the same issue where the inputs of my model are: […] Alternatively, is it possible to modify the shape via synapse.ml.onnx syntax, as I think my shape is in reality fixed at 512?
Which version did you try? Please use this coordinate: com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT with the resolver: https://mmlspark.azureedge.net/maven. The fix is in the master branch, but not released yet.
Shape modification is not supported.
I just created a Synapse cluster with Spark 3.3 in Azure. Here is the list of installed packages with synapse in them:
azure-synapse-ml-predict==1.0.0
azureml-synapse==0.0.1
synapseml-cognitive==0.10.1.dev1
synapseml-core==0.10.1.dev1
synapseml-deep-learning==0.10.1.dev1
synapseml-internal==0.0.0.dev1
synapseml-lightgbm==0.10.1.dev1
synapseml-opencv==0.10.1.dev1
synapseml-vw==0.10.1.dev1
@fralik Can you please try the following configuration for Synapse 3.2 pool (until we release the next version), by placing the following in the first cell of your notebook:
%%configure -f
{
"name": "synapseml",
"conf": {
"spark.jars.packages": "com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT,org.apache.spark:spark-avro_2.12:3.3.1",
"spark.jars.repositories": "https://mmlspark.azureedge.net/maven",
"spark.jars.excludes": "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalactic:scalactic_2.12,org.scalatest:scalatest_2.12,com.fasterxml.jackson.core:jackson-databind",
"spark.yarn.user.classpath.first": "true",
"spark.sql.parquet.enableVectorizedReader": "false",
"spark.sql.legacy.replaceDatabricksSparkAvro.enabled": "true"
}
}
@memoryz , thanks so much! It did work with this version.
When do you plan to release the next version? I am actually using Substrate in O365 where Synapse and Spark pools are provisioned for us. Usage of Azure Synapse was only for a quick proof of concept, but customer data can only be handled via Substrate. In Substrate, I won't be able to configure the pool myself and would need to ask for onboarding of the new version.
@mhamilton723, when will we release the next version and get it included into Synapse runtime?
@fralik, AFAIK, it takes some time for a new release to be included in the Synapse runtime after it ships. Meanwhile, the %%configure cell can be added to the notebook, so it doesn't require you to modify the pool directly. Would this work for you for now?
Thanks @memoryz . I think I am good for now.
Thanks for confirmation. Closing.
SynapseML version
0.11.0
System information
Databricks runtime version: 10.4 LTS (includes Apache Spark 3.2.1, Scala 2.12)
Python version: 3.8.10
Describe the problem
Trying to get Spark inference working for an EBM model converted to ONNX via the interpret and ebm2onnx packages. Code to reproduce the example is given below, along with the final error produced when calling .transform().
Code to reproduce issue
Other info / logs
What component(s) does this bug affect?
- area/cognitive: Cognitive project
- area/core: Core project
- area/deep-learning: DeepLearning project
- area/lightgbm: Lightgbm project
- area/opencv: Opencv project
- area/vw: VW project
- area/website: Website
- area/build: Project build system
- area/notebooks: Samples under notebooks folder
- area/docker: Docker usage
- area/models: models related issue
What language(s) does this bug affect?
- language/scala: Scala source code
- language/python: Pyspark APIs
- language/r: R APIs
- language/csharp: .NET APIs
- language/new: Proposals for new client languages
What integration(s) does this bug affect?
- integrations/synapse: Azure Synapse integrations
- integrations/azureml: Azure ML integrations
- integrations/databricks: Databricks integrations