microsoft / SynapseML

Simple and Distributed Machine Learning
http://aka.ms/spark
MIT License

Support for EBM/ONNX Model (Possible Bug?) [BUG] [HELP] #1902

Closed: brandongreenwell-8451 closed this issue 1 year ago

brandongreenwell-8451 commented 1 year ago

SynapseML version

0.11.0

System information

Databricks runtime version: 10.4 LTS (includes Apache Spark 3.2.1, Scala 2.12)

Python version: 3.8.10

Describe the problem

Trying to get Spark inference working for an EBM model converted to ONNX via the interpret and ebm2onnx packages. Code to reproduce the issue is given below. The final call to .transform() fails with

java.lang.ClassCastException: net.razorvine.pickle.objects.ClassDict cannot be cast to java.lang.String

Code to reproduce issue

import numpy as np
import pandas as pd
import onnx
import ebm2onnx

from sklearn.model_selection import train_test_split
from interpret.glassbox import ExplainableBoostingClassifier
from synapse.ml.onnx import ONNXModel

# Read in adult data from UCI ML repo
df = pd.read_csv(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
    header=None)
df.columns = [
    "Age", "WorkClass", "fnlwgt", "Education", "EducationNum",
    "MaritalStatus", "Occupation", "Relationship", "Race", "Gender",
    "CapitalGain", "CapitalLoss", "HoursPerWeek", "NativeCountry", "Income"
]

# Sample data and split into train/test sets
seed = 42
np.random.seed(seed)
df = df.sample(frac=0.05, random_state=seed)
train_cols = df.columns[0:-1]
label = df.columns[-1]
X = df[train_cols]
y = df[label]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)

# Fit a (default) EBM model
ebm = ExplainableBoostingClassifier()
ebm.fit(X_train, y_train)

def convert_model(model, input_df):
    onnx_model = ebm2onnx.to_onnx(
        model,
        ebm2onnx.get_dtype_from_pandas(input_df),
        predict_proba=True
    )
    onnx_model.ir_version = 4 
    return onnx_model.SerializeToString()

# Load ONNX payload into an ONNXModel and inspect inputs/outputs.
payload = convert_model(ebm, input_df=X_train)
onnx_ml = ONNXModel().setModelPayload(payload)
print("Model inputs:" + str(onnx_ml.getModelInputs()))
print("Model outputs:" + str(onnx_ml.getModelOutputs()))

# Map the model input to the input dataframe's column name (FeedDict), and 
# map the output dataframe's column names to the model outputs (FetchDict)
onnx_ml = (
    onnx_ml.setDeviceType("CPU")
    .setFeedDict(onnx_ml.getModelInputs())    # not sure if this is the correct way?!
    .setFetchDict(onnx_ml.getModelOutputs())  # not sure if this is the correct way?
    .setMiniBatchSize(5000)
)

# Coerce test data features to Spark DataFrame and transform (i.e., compute and add scores).
# `spark` and `display` are globals provided by the Databricks notebook environment.
X_test_sdf = spark.createDataFrame(X_test)
display(onnx_ml.transform(X_test_sdf))

# Error from last line calling `.transform()`:
#
# java.lang.ClassCastException: net.razorvine.pickle.objects.ClassDict cannot be cast to java.lang.String

Other info / logs

java.lang.ClassCastException: net.razorvine.pickle.objects.ClassDict cannot be cast to java.lang.String
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<command-189663687888411> in <module>
      1 # Coerce test data features to Spark DataFrame and transform (i.e., compute and add scores)
      2 X_test_sdf = spark.createDataFrame(X_test)
----> 3 display(onnx_ml.transform(X_test_sdf))

/databricks/spark/python/pyspark/ml/base.py in transform(self, dataset, params)
    215                 return self.copy(params)._transform(dataset)
    216             else:
--> 217                 return self._transform(dataset)
    218         else:
    219             raise TypeError("Params must be a param map but got %s." % type(params))

/databricks/spark/python/pyspark/ml/wrapper.py in _transform(self, dataset)
    352     def _transform(self, dataset):
    353         self._transfer_params_to_java()
--> 354         return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
    355 
    356 

/databricks/spark/python/lib/py4j-0.10.9.1-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1302 
   1303         answer = self.gateway_client.send_command(command)
-> 1304         return_value = get_return_value(
   1305             answer, self.gateway_client, self.target_id, self.name)
   1306 

/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
    115     def deco(*a, **kw):
    116         try:
--> 117             return f(*a, **kw)
    118         except py4j.protocol.Py4JJavaError as e:
    119             converted = convert_exception(e.java_exception)

/databricks/spark/python/lib/py4j-0.10.9.1-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    324             value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325             if answer[1] == REFERENCE_TYPE:
--> 326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
    328                     format(target_id, ".", name), value)

Py4JJavaError: An error occurred while calling o906.transform.
: java.lang.ClassCastException: net.razorvine.pickle.objects.ClassDict cannot be cast to java.lang.String
    at com.microsoft.azure.synapse.ml.onnx.ONNXModel.$anonfun$validateSchema$1(ONNXModel.scala:357)
    at com.microsoft.azure.synapse.ml.onnx.ONNXModel.$anonfun$validateSchema$1$adapted(ONNXModel.scala:353)
    at scala.collection.immutable.HashMap$HashMap1.foreach(HashMap.scala:400)
    at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:728)
    at com.microsoft.azure.synapse.ml.onnx.ONNXModel.validateSchema(ONNXModel.scala:353)
    at com.microsoft.azure.synapse.ml.onnx.ONNXModel.$anonfun$transform$1(ONNXModel.scala:213)
    at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logVerb(SynapseMLLogging.scala:81)
    at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logVerb$(SynapseMLLogging.scala:78)
    at com.microsoft.azure.synapse.ml.onnx.ONNXModel.logVerb(ONNXModel.scala:145)
    at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logTransform(SynapseMLLogging.scala:75)
    at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logTransform$(SynapseMLLogging.scala:74)
    at com.microsoft.azure.synapse.ml.onnx.ONNXModel.logTransform(ONNXModel.scala:145)
    at com.microsoft.azure.synapse.ml.onnx.ONNXModel.transform(ONNXModel.scala:211)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
    at py4j.Gateway.invoke(Gateway.java:295)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:251)
    at java.lang.Thread.run(Thread.java:750)


github-actions[bot] commented 1 year ago

Hey @brandongreenwell-8451 👋! Thank you so much for reporting the issue/feature request 🚨. Someone from SynapseML Team will be looking to triage this issue soon. We appreciate your patience.

svotaw commented 1 year ago

asked Jason to take a look

memoryz commented 1 year ago

ack. I'll take a look shortly.

memoryz commented 1 year ago

Hi @brandongreenwell-8451, I had some trouble running your code in Synapse, and I don't have access to Databricks. But reading your code, I think the problem is here:

onnx_ml = (
    onnx_ml.setDeviceType("CPU")
    .setFeedDict(onnx_ml.getModelInputs())    # not sure if this is the correct way?!
    .setFetchDict(onnx_ml.getModelOutputs())  # not sure if this is the correct way?
)

The setFeedDict method sets up the mapping between the ONNX input nodes and the columns of the input Spark dataframe (X_test_sdf), and the setFetchDict method sets up the mapping between the ONNX output nodes and the columns of the output Spark dataframe.

The setFeedDict method expects a dict with keys being the ONNX input node names, and the values being the corresponding column name from the input Spark dataframe. The setFetchDict method expects a dict with keys being the column name from the output Spark dataframe, and values being the corresponding ONNX output node name.
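
Because the two dictionaries point in opposite directions, it can help to build and check them explicitly before calling .transform(). The sketch below assumes, as in the ebm2onnx export in this issue, that every ONNX input node has a DataFrame column with the same name. The helpers `build_feed_dict` and `check_fetch_dict` are hypothetical, not part of SynapseML; they only produce or validate the plain dicts that setFeedDict/setFetchDict take. (`predict_proba_0` in the usage comment is the output name this EBM model reports further down in the thread.)

```python
def build_feed_dict(model_input_names, df_columns):
    """Map each ONNX input node to the DataFrame column of the same name.

    Returns {ONNX input node name: DataFrame column name}, the direction
    setFeedDict expects. Raises ValueError if an input has no matching column.
    """
    available = set(df_columns)
    missing = sorted(name for name in model_input_names if name not in available)
    if missing:
        raise ValueError(f"No DataFrame column for ONNX inputs: {missing}")
    return {name: name for name in model_input_names}


def check_fetch_dict(fetch_dict, model_output_names):
    """Verify a {output column name: ONNX output node name} mapping.

    Catches the mistake of writing the fetch dict the other way round.
    Returns the dict unchanged so it can be passed straight to setFetchDict.
    """
    outputs = set(model_output_names)
    unknown = sorted(node for node in fetch_dict.values() if node not in outputs)
    if unknown:
        raise ValueError(f"Not ONNX output nodes: {unknown}")
    return fetch_dict


# Usage sketch against the repro above (assumes onnx_ml and X_test_sdf exist;
# getModelInputs()/getModelOutputs() return dicts keyed by node name):
# onnx_ml = (
#     onnx_ml.setFeedDict(build_feed_dict(onnx_ml.getModelInputs(), X_test_sdf.columns))
#     .setFetchDict(check_fetch_dict({"probability": "predict_proba_0"},
#                                    onnx_ml.getModelOutputs()))
# )
```

Failing fast on the driver with a readable list of unmatched names is usually easier to debug than a ClassCastException from inside validateSchema.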

For example, if the ONNX model has an input node named "input" and output nodes named "probabilities" and "label", the model can be set up as follows:

onnx_ml = (
    onnx_ml.setDeviceType("CPU")
    .setFeedDict({"input": "features"})
    .setFetchDict({"probability": "probabilities", "prediction": "label"})
    .setMiniBatchSize(5000)
)

If you still have trouble, please show the output from these two lines:

print("Model inputs:" + str(onnx_ml.getModelInputs()))
print("Model outputs:" + str(onnx_ml.getModelOutputs()))

and also show the schema of X_test_sdf by:

X_test_sdf.printSchema()

brandongreenwell-8451 commented 1 year ago

Hi @memoryz, thank you for taking a look. I am still having trouble, however, so here's the output you requested:

print("Model inputs:" + str(onnx_ml.getModelInputs()))

Model inputs:{'HoursPerWeek': NodeInfo(name=HoursPerWeek,info=TensorInfo(shape=[-1], type=INT64)), 'CapitalGain': NodeInfo(name=CapitalGain,info=TensorInfo(shape=[-1], type=INT64)), 'Education': NodeInfo(name=Education,info=TensorInfo(shape=[-1], type=STRING)), 'WorkClass': NodeInfo(name=WorkClass,info=TensorInfo(shape=[-1], type=STRING)), 'Race': NodeInfo(name=Race,info=TensorInfo(shape=[-1], type=STRING)), 'Age': NodeInfo(name=Age,info=TensorInfo(shape=[-1], type=INT64)), 'Relationship': NodeInfo(name=Relationship,info=TensorInfo(shape=[-1], type=STRING)), 'EducationNum': NodeInfo(name=EducationNum,info=TensorInfo(shape=[-1], type=INT64)), 'NativeCountry': NodeInfo(name=NativeCountry,info=TensorInfo(shape=[-1], type=STRING)), 'CapitalLoss': NodeInfo(name=CapitalLoss,info=TensorInfo(shape=[-1], type=INT64)), 'Gender': NodeInfo(name=Gender,info=TensorInfo(shape=[-1], type=STRING)), 'MaritalStatus': NodeInfo(name=MaritalStatus,info=TensorInfo(shape=[-1], type=STRING)), 'Occupation': NodeInfo(name=Occupation,info=TensorInfo(shape=[-1], type=STRING)), 'fnlwgt': NodeInfo(name=fnlwgt,info=TensorInfo(shape=[-1], type=INT64))}
print("Model outputs:" + str(onnx_ml.getModelOutputs()))

Model outputs:{'predict_proba_0': NodeInfo(name=predict_proba_0,info=TensorInfo(shape=[-1,2], type=FLOAT))}

print(X_test_sdf.printSchema())

root
 |-- Age: long (nullable = true)
 |-- WorkClass: string (nullable = true)
 |-- fnlwgt: long (nullable = true)
 |-- Education: string (nullable = true)
 |-- EducationNum: long (nullable = true)
 |-- MaritalStatus: string (nullable = true)
 |-- Occupation: string (nullable = true)
 |-- Relationship: string (nullable = true)
 |-- Race: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- CapitalGain: long (nullable = true)
 |-- CapitalLoss: long (nullable = true)
 |-- HoursPerWeek: long (nullable = true)
 |-- NativeCountry: string (nullable = true)

None

Based on your comment, I changed these parameters to:

onnx_ml = (
    onnx_ml.setDeviceType("CPU")
    .setFeedDict(dict(zip(X_test_sdf.columns, X_test_sdf.columns)))  # keys/values match column names 
    .setFetchDict({'predict_proba_0': 'predict_proba_0'}) 
    .setMiniBatchSize(5000)
)

The .transform() call still failed, but not right away this time: the error only surfaced once the job ran. Here's the new error:

IllegalArgumentException: **Image batch is not a sequence**

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 2.0 failed 4 times, most recent failure: Lost task 0.3 in stage 2.0 (TID 11) (10.2.64.8 executor 0): java.lang.IllegalArgumentException: Image batch is not a sequence
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1(ONNXUtils.scala:130)
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1$adapted(ONNXUtils.scala:128)
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.validateBatchShapes(ONNXUtils.scala:128)
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.createTensor(ONNXUtils.scala:104)
    at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$2(ONNXRuntime.scala:74)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.Iterator.foreach(Iterator.scala:943)
    at scala.collection.Iterator.foreach$(Iterator.scala:943)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
    at scala.collection.IterableLike.foreach(IterableLike.scala:74)
    at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
    at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$1(ONNXRuntime.scala:67)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at com.microsoft.azure.synapse.ml.core.utils.CloseableIterator.next(CloseableIterator.scala:11)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:155)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:156)
    at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:125)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.Task.run(Task.scala:95)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:832)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1681)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:835)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:690)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:750)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3029)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2976)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2970)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2970)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1390)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1390)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1390)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3238)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3179)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3167)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1152)
    at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2657)
    at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:266)
    at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:276)
    at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:81)
    at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:87)
    at org.apache.spark.sql.execution.collect.InternalRowFormat$.collect(cachedSparkResults.scala:75)
    at org.apache.spark.sql.execution.collect.InternalRowFormat$.collect(cachedSparkResults.scala:62)
    at org.apache.spark.sql.execution.ResultCacheManager.collectResult$1(ResultCacheManager.scala:573)
    at org.apache.spark.sql.execution.ResultCacheManager.computeResult(ResultCacheManager.scala:582)
    at org.apache.spark.sql.execution.ResultCacheManager.$anonfun$getOrComputeResultInternal$1(ResultCacheManager.scala:528)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResultInternal(ResultCacheManager.scala:527)
    at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:424)
    at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:403)
    at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:424)
    at org.apache.spark.sql.Dataset.collectResult(Dataset.scala:3153)
    at org.apache.spark.sql.Dataset.$anonfun$collectResult$1(Dataset.scala:3144)
    at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3951)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withCustomExecutionEnv$8(SQLExecution.scala:239)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:386)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withCustomExecutionEnv$1(SQLExecution.scala:186)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:968)
    at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:141)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:336)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3949)
    at org.apache.spark.sql.Dataset.collectResult(Dataset.scala:3143)
    at com.databricks.backend.daemon.driver.OutputAggregator$.withOutputAggregation0(OutputAggregator.scala:266)
    at com.databricks.backend.daemon.driver.OutputAggregator$.withOutputAggregation(OutputAggregator.scala:100)
    at com.databricks.backend.daemon.driver.PythonDriverLocalBase.generateTableResult(PythonDriverLocalBase.scala:723)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.computeListResultsItem(PythonDriverLocal.scala:627)
    at com.databricks.backend.daemon.driver.PythonDriverLocalBase.genListResults(PythonDriverLocalBase.scala:630)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.$anonfun$getResultBufferInternal$1(PythonDriverLocal.scala:682)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.withInterpLock(PythonDriverLocal.scala:563)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.getResultBufferInternal(PythonDriverLocal.scala:642)
    at com.databricks.backend.daemon.driver.DriverLocal.getResultBuffer(DriverLocal.scala:748)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.outputSuccess(PythonDriverLocal.scala:605)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.$anonfun$repl$6(PythonDriverLocal.scala:223)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.withInterpLock(PythonDriverLocal.scala:563)
    at com.databricks.backend.daemon.driver.PythonDriverLocal.repl(PythonDriverLocal.scala:210)
    at com.databricks.backend.daemon.driver.DriverLocal.$anonfun$execute$13(DriverLocal.scala:634)
    at com.databricks.logging.Log4jUsageLoggingShim$.$anonfun$withAttributionContext$1(Log4jUsageLoggingShim.scala:33)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
    at com.databricks.logging.AttributionContext$.withValue(AttributionContext.scala:94)
    at com.databricks.logging.Log4jUsageLoggingShim$.withAttributionContext(Log4jUsageLoggingShim.scala:31)
    at com.databricks.logging.UsageLogging.withAttributionContext(UsageLogging.scala:205)
    at com.databricks.logging.UsageLogging.withAttributionContext$(UsageLogging.scala:204)
    at com.databricks.backend.daemon.driver.DriverLocal.withAttributionContext(DriverLocal.scala:59)
    at com.databricks.logging.UsageLogging.withAttributionTags(UsageLogging.scala:240)
    at com.databricks.logging.UsageLogging.withAttributionTags$(UsageLogging.scala:225)
    at com.databricks.backend.daemon.driver.DriverLocal.withAttributionTags(DriverLocal.scala:59)
    at com.databricks.backend.daemon.driver.DriverLocal.execute(DriverLocal.scala:611)
    at com.databricks.backend.daemon.driver.DriverWrapper.$anonfun$tryExecutingCommand$1(DriverWrapper.scala:615)
    at scala.util.Try$.apply(Try.scala:213)
    at com.databricks.backend.daemon.driver.DriverWrapper.tryExecutingCommand(DriverWrapper.scala:607)
    at com.databricks.backend.daemon.driver.DriverWrapper.executeCommandAndGetError(DriverWrapper.scala:526)
    at com.databricks.backend.daemon.driver.DriverWrapper.executeCommand(DriverWrapper.scala:561)
    at com.databricks.backend.daemon.driver.DriverWrapper.runInnerLoop(DriverWrapper.scala:431)
    at com.databricks.backend.daemon.driver.DriverWrapper.runInner(DriverWrapper.scala:374)
    at com.databricks.backend.daemon.driver.DriverWrapper.run(DriverWrapper.scala:225)
    at java.lang.Thread.run(Thread.java:750)
Caused by: java.lang.IllegalArgumentException: Image batch is not a sequence
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1(ONNXUtils.scala:130)
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.$anonfun$validateBatchShapes$1$adapted(ONNXUtils.scala:128)
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.validateBatchShapes(ONNXUtils.scala:128)
    at com.microsoft.azure.synapse.ml.onnx.ONNXUtils$.createTensor(ONNXUtils.scala:104)
    at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$2(ONNXRuntime.scala:74)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.Iterator.foreach(Iterator.scala:943)
    at scala.collection.Iterator.foreach$(Iterator.scala:943)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
    at scala.collection.IterableLike.foreach(IterableLike.scala:74)
    at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
    at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at com.microsoft.azure.synapse.ml.onnx.ONNXRuntime$.$anonfun$applyModel$1(ONNXRuntime.scala:67)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at com.microsoft.azure.synapse.ml.core.utils.CloseableIterator.next(CloseableIterator.scala:11)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:155)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:156)
    at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:125)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.Task.run(Task.scala:95)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:832)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1681)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:835)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:690)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more
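
Every model input printed earlier in this comment has shape [-1], i.e. one scalar per row, so for a minibatch of n rows the ONNX runtime expects each input as a flat 1-D tensor of length n rather than n nested sequences. The NumPy sketch below only illustrates that expected layout; it is not how SynapseML builds its tensors internally. `to_feed_arrays` is a hypothetical helper and the rows are made-up values for two of the adult-data columns.

```python
import numpy as np

# ONNX element types seen in this thread, mapped to NumPy dtypes.
ONNX_TO_NUMPY = {"INT64": np.int64, "FLOAT": np.float32, "STRING": object}


def to_feed_arrays(rows, feed_dict, onnx_types):
    """Turn a minibatch of rows into one 1-D array per ONNX input (hypothetical helper).

    rows:       list of {column name: scalar value}
    feed_dict:  {ONNX input name: column name}, as passed to setFeedDict
    onnx_types: {ONNX input name: ONNX element type, e.g. "INT64" or "STRING"}
    """
    feeds = {}
    for input_name, column in feed_dict.items():
        values = [row[column] for row in rows]
        # Shape [-1]: one scalar per row, so a flat array of length len(rows),
        # not len(rows) one-element sequences.
        feeds[input_name] = np.array(values, dtype=ONNX_TO_NUMPY[onnx_types[input_name]])
    return feeds


# Made-up rows using two of the adult-data columns.
rows = [
    {"Age": 39, "Race": "White"},
    {"Age": 50, "Race": "Black"},
    {"Age": 38, "Race": "White"},
]
feeds = to_feed_arrays(rows, {"Age": "Age", "Race": "Race"}, {"Age": "INT64", "Race": "STRING"})
print({name: arr.shape for name, arr in feeds.items()})  # {'Age': (3,), 'Race': (3,)}
```

With onnxruntime installed and all 14 inputs filled in, a dict shaped like this could be passed to `onnxruntime.InferenceSession(payload).run(None, feeds)` to check the exported EBM locally, independently of Spark.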
MainRo commented 1 year ago

Hello @brandongreenwell-8451, the output shape of the EBM with probabilities is [-1, 2]. I do not know how SynapseML maps a multi-dimensional output to a column; maybe this is the cause of the issue. Can you serialize the model without predict_proba to test whether that helps? The output will then be only the predicted class instead of the probability for each class (and the output will have a different name).
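
To make the shape difference concrete: a [-1, 2] output holds two class probabilities per input row, so a minibatch of n rows yields an (n, 2) array, while a predict-only output is a flat (n,) array. The NumPy sketch below uses made-up probabilities and assumes column 1 is the positive class (the usual scikit-learn-style class ordering; checking `ebm.classes_` would confirm it). It shows how the positive-class probability and a predicted class index can both be recovered from the (n, 2) array.

```python
import numpy as np

# Illustrative (n, 2) probability output for n = 3 rows; values are made up.
proba = np.array([
    [0.90, 0.10],
    [0.25, 0.75],
    [0.60, 0.40],
], dtype=np.float32)

assert proba.ndim == 2 and proba.shape[1] == 2  # matches TensorInfo shape [-1, 2]

# Assumption: column 1 corresponds to the positive class (e.g. ebm.classes_[1]).
positive_proba = proba[:, 1]            # shape (n,)
predicted_class = proba.argmax(axis=1)  # shape (n,), analogous to a [-1] predict output
print(positive_proba.shape, predicted_class.tolist())  # (3,) [0, 1, 0]
```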

brandongreenwell-8451 commented 1 year ago

Hi @MainRo, same error when serializing without the predict_proba option and updating the .setFetchDict argument accordingly.

print("Model outputs:" + str(onnx_ml.getModelOutputs()))

Model outputs:{'predict_0': NodeInfo(name=predict_0,info=TensorInfo(shape=[-1], type=INT64))}

brandongreenwell-8451 commented 1 year ago

Since this is a binary outcome, the goal is to be able to predict probabilities, as opposed to class labels.

memoryz commented 1 year ago

@brandongreenwell-8451, I was finally able to reproduce the issue and have sent out PR #1906.

bgreenwell commented 1 year ago

@memoryz, this is great to hear! Will this also fix the case for predict_proba=True, which gives shape [-1, 2]?

memoryz commented 1 year ago

The case for predict_proba=True is not affected by this bug or the bug fix. When the output shape is [-1, 2], the output column contains arrays of length 2. To illustrate the usage (Scala):

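    import com.microsoft.azure.synapse.ml.onnx.ONNXModel
    import org.apache.spark.sql.functions.col
    import spark.implicits._ // encoder for .as[(Seq[Double], Double)]; `spark` is the active SparkSession

    // `model` (the ONNX model file), `featuresAdultsIncome` (the feature column names)
    // and `testDfGH1902` (the test DataFrame) are defined elsewhere (not shown).
    val onnxGH1902 = new ONNXModel()
      .setModelLocation(model.getPath)
      .setDeviceType("CPU")
      // Identity mapping: each ONNX input node reads the column of the same name
      .setFeedDict(featuresAdultsIncome.map(v => (v, v)).toMap)
      // Output column "probability" <- ONNX output node "probabilities" (length-2 arrays)
      .setFetchDict(Map("probability" -> "probabilities"))
      // Adds a "prediction" column holding the argmax index of "probability"
      .setArgMaxDict(Map("probability" -> "prediction"))
      .setMiniBatchSize(5000)

    val Array(row1, row2) = onnxGH1902.transform(testDfGH1902)
      .select("probability", "prediction")
      .orderBy(col("prediction"))
      .as[(Seq[Double], Double)]
      .collect()

    assert(row1._1 === Seq(0.9343283176422119, 0.0656716451048851))
    assert(row1._2 === 0.0)

    assert(row2._1 === Seq(0.16954122483730316, 0.8304587006568909))
    assert(row2._2 === 1.0)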
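
As the Scala assertions show, the probability column holds one length-2 sequence per row and the prediction column holds the index of the larger entry as a double. The plain-Python sketch below mimics that argmax step on the two probability rows from the test, to show what setArgMaxDict adds; `argmax_as_double` is a hypothetical helper, not SynapseML's implementation. The commented PySpark configuration assumes the Python ONNXModel wrapper exposes the same setters as the Scala class, and uses the `predict_proba_0` output name of the EBM model from this issue.

```python
def argmax_as_double(probabilities):
    """Index of the largest probability, returned as a float (ties -> first index)."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return float(best)


# The two probability rows asserted in the Scala test above.
probability_rows = [
    [0.9343283176422119, 0.0656716451048851],
    [0.16954122483730316, 0.8304587006568909],
]
print([argmax_as_double(p) for p in probability_rows])  # [0.0, 1.0]

# PySpark sketch of the same configuration for the EBM model (untested assumption):
# onnx_ml = (
#     ONNXModel().setModelPayload(payload)
#     .setDeviceType("CPU")
#     .setFeedDict({c: c for c in X_test_sdf.columns})
#     .setFetchDict({"probability": "predict_proba_0"})
#     .setArgMaxDict({"probability": "prediction"})
#     .setMiniBatchSize(5000)
# )
```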
memoryz commented 1 year ago

@brandongreenwell-8451, to test the fix, please use this coordinate: com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT with the resolver: https://mmlspark.azureedge.net/maven. This is from the PR build.
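
Outside Synapse and Databricks, the same coordinate and resolver can be passed to a plain PySpark session through Spark's standard `spark.jars.packages` and `spark.jars.repositories` settings (the same keys used in the %%configure cell later in this thread). A minimal sketch; `snapshot_conf` is a hypothetical helper, and the SparkSession usage is left as comments because it needs pyspark and network access.

```python
COORDINATE = "com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT"
RESOLVER = "https://mmlspark.azureedge.net/maven"


def snapshot_conf(coordinate, repository):
    """Spark settings that pull a Maven package from an extra repository (hypothetical helper)."""
    group, artifact, version = coordinate.split(":")  # Maven group:artifact:version
    if not (group and artifact and version):
        raise ValueError(f"Malformed Maven coordinate: {coordinate!r}")
    return {
        "spark.jars.packages": coordinate,
        "spark.jars.repositories": repository,
    }


conf = snapshot_conf(COORDINATE, RESOLVER)

# Usage (requires pyspark; these settings must be applied before the session starts):
# from pyspark.sql import SparkSession
# builder = SparkSession.builder.appName("synapseml-snapshot")
# for key, value in conf.items():
#     builder = builder.config(key, value)
# spark = builder.getOrCreate()
```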

brandongreenwell-8451 commented 1 year ago

Huzzah!! @memoryz, I am indeed getting probabilities now! This is great, and our team truly appreciates the effort!

fralik commented 1 year ago

@memoryz, is it possible to use a package with this fix in Python? I think I hit the same issue; the inputs of my model are:

Model inputs: {
  'attention_mask': NodeInfo(name=attention_mask, info=TensorInfo(shape=[-1,-1], type=INT64)),
  'input_ids': NodeInfo(name=input_ids, info=TensorInfo(shape=[-1,-1], type=INT64)),
  'token_type_ids': NodeInfo(name=token_type_ids, info=TensorInfo(shape=[-1,-1], type=INT64))
}

Alternatively, is it possible to modify the shape via the synapse.ml.onnx API, since I think my shape is in reality fixed to 512?

brandongreenwell-8451 commented 1 year ago

@fralik, my original issue was in Python (see the original reproducible example), and this seems to have fixed it for me.

memoryz commented 1 year ago

@fralik, which version did you try? Please use this coordinate: com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT with the resolver: https://mmlspark.azureedge.net/maven. The fix is in the master branch but has not been released yet.

Shape modification is not supported.

fralik commented 1 year ago

I just created a Synapse cluster with Spark 3.3 in Azure. Here is the list of installed packages with "synapse" in their names:

azure-synapse-ml-predict==1.0.0
azureml-synapse==0.0.1
synapseml-cognitive==0.10.1.dev1
synapseml-core==0.10.1.dev1
synapseml-deep-learning==0.10.1.dev1
synapseml-internal==0.0.0.dev1
synapseml-lightgbm==0.10.1.dev1
synapseml-opencv==0.10.1.dev1
synapseml-vw==0.10.1.dev1

memoryz commented 1 year ago

@fralik Can you please try the following configuration for Synapse 3.2 pool (until we release the next version), by placing the following in the first cell of your notebook:

%%configure -f
{
  "name": "synapseml",
  "conf": {
      "spark.jars.packages": "com.microsoft.azure:synapseml_2.12:0.11.0-36-86ce42ba-SNAPSHOT,org.apache.spark:spark-avro_2.12:3.3.1",
      "spark.jars.repositories": "https://mmlspark.azureedge.net/maven",
      "spark.jars.excludes": "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalactic:scalactic_2.12,org.scalatest:scalatest_2.12,com.fasterxml.jackson.core:jackson-databind",
      "spark.yarn.user.classpath.first": "true",
      "spark.sql.parquet.enableVectorizedReader": "false",
      "spark.sql.legacy.replaceDatabricksSparkAvro.enabled": "true"
  }
}

fralik commented 1 year ago

@memoryz, thanks so much! It worked with this version.

When do you plan to release the next version? I am actually using Substrate in O365 where Synapse and Spark pools are provisioned for us. Usage of Azure Synapse was only for a quick proof of concept, but customer data can only be handled via Substrate. In Substrate, I won't be able to configure the pool myself and would need to ask for onboarding of the new version.

memoryz commented 1 year ago

@mhamilton723, when will we release the next version and get it included into Synapse runtime?

@fralik, AFAIK, it takes some time after a release before it is included in the Synapse runtime. Meanwhile, the %%configure cell can be added to the notebook, so you don't need to modify the pool directly. Would this work for you for now?

fralik commented 1 year ago

Thanks, @memoryz. I think I am good for now.

memoryz commented 1 year ago

Thanks for confirmation. Closing.