eto-ai / rikai

Parquet-based ML data format optimized for working with unstructured data
https://rikai.readthedocs.io/en/latest/
Apache License 2.0

Failed to serialize and deserialize mask #612

Closed · da-liii closed this issue 2 years ago

da-liii commented 2 years ago

The full ModelType: https://github.com/da-tubi/rikai-ocr/blob/main/rikai/contrib/ocr/keras.py

I'm using array<struct<text:string, mask:box2d>> as a workaround.

See https://github.com/da-tubi/rikai-ocr/blob/f99651863803ea43e5f97a5a9019fcbf6e0e3048/rikai/contrib/ocr/keras.py#L63-L64

If you use array<struct<text:string, mask:mask>>, ser/deser fails because the mask holds a numpy ndarray.

To reproduce it, just replace the workaround schema with array<struct<text:string, mask:mask>> and the convert method with convert_pred_groups_for_rikai.

Here is a related Stack Overflow question: https://stackoverflow.com/questions/38984775/spark-errorexpected-zero-arguments-for-construction-of-classdict-for-numpy-cor
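
For illustration only, here is a minimal sketch of what the workaround's convert step amounts to. It assumes keras-ocr's (text, box) output, where box is a 4x2 numpy array of corner points, and assumes rikai's Box2d takes (xmin, ymin, xmax, ymax); see the linked keras.py for the actual code:

from rikai.types import Box2d  # assumed import path; check your rikai version

def convert_pred_groups(pred_groups):
    """Hypothetical helper: emit only plain Python types plus Box2d."""
    rows = []
    for text, box in pred_groups:  # box: 4x2 numpy array of corners (assumed)
        xs, ys = box[:, 0], box[:, 1]
        rows.append({
            "text": str(text),  # plain Python str, safe to pickle
            "mask": Box2d(      # axis-aligned bounding box instead of raw ndarray
                xmin=float(xs.min()),
                ymin=float(ys.min()),
                xmax=float(xs.max()),
                ymax=float(ys.max()),
            ),
        })
    return rows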

eddyxu commented 2 years ago

Could this help? https://rikai.readthedocs.io/en/latest/numpy.html

What do the OCR model outputs actually mean? Storing the output as raw numpy usually makes it not analyzable via SQL.

We converted tensors / numpy arrays to (label, score, box) in the torch / tensorflow vision models for the same reason.
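
For illustration, a minimal sketch of that conversion for a torchvision-style detection output (the output keys and the class_names list are assumptions, not rikai's actual code):

def detections_to_rows(output, class_names):
    """Cast every tensor/numpy scalar to a plain Python type."""
    rows = []
    for label, score, box in zip(output["labels"], output["scores"], output["boxes"]):
        rows.append({
            "label": class_names[int(label)],  # plain Python str
            "score": float(score),             # plain float, not numpy float32
            "box": [float(v) for v in box],    # plain list of floats
        })
    return rows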

da-liii commented 2 years ago

Here is the reproducer branch: https://github.com/da-tubi/rikai-ocr/pull/1 (Please use the KerasOCRikai.ipynb to reproduce it)

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
Input In [3], in <cell line: 2>()
      1 image_uri = 'https://www.rochester.edu/newscenter/wp-content/uploads/2022/03/fea-ukraine-russian-war.jpg'
----> 2 df = spark.sql(f"""
      3 select pred.text, pred.mask
      4 from (
      5   select explode(ML_PREDICT(keras_ocr, to_image('{image_uri}'))) as pred
      6 )
      7 """).toPandas()

File ~/.pyenv/versions/3.8.10/envs/rikai-ocr/lib/python3.8/site-packages/pyspark/sql/pandas/conversion.py:141, in PandasConversionMixin.toPandas(self)
    138             raise
    140 # Below is toPandas without Arrow optimization.
--> 141 pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
    142 column_counter = Counter(self.columns)
    144 dtype = [None] * len(self.schema)

File ~/.pyenv/versions/3.8.10/envs/rikai-ocr/lib/python3.8/site-packages/pyspark/sql/dataframe.py:677, in DataFrame.collect(self)
    667 """Returns all the records as a list of :class:`Row`.
    668 
    669 .. versionadded:: 1.3.0
   (...)
    674 [Row(age=2, name='Alice'), Row(age=5, name='Bob')]
    675 """
    676 with SCCallSiteSync(self._sc) as css:
--> 677     sock_info = self._jdf.collectToPython()
    678 return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))

File ~/.pyenv/versions/3.8.10/envs/rikai-ocr/lib/python3.8/site-packages/py4j/java_gateway.py:1304, in JavaMember.__call__(self, *args)
   1298 command = proto.CALL_COMMAND_NAME +\
   1299     self.command_header +\
   1300     args_command +\
   1301     proto.END_COMMAND_PART
   1303 answer = self.gateway_client.send_command(command)
-> 1304 return_value = get_return_value(
   1305     answer, self.gateway_client, self.target_id, self.name)
   1307 for temp_arg in temp_args:
   1308     temp_arg._detach()

File ~/.pyenv/versions/3.8.10/envs/rikai-ocr/lib/python3.8/site-packages/pyspark/sql/utils.py:111, in capture_sql_exception.<locals>.deco(*a, **kw)
    109 def deco(*a, **kw):
    110     try:
--> 111         return f(*a, **kw)
    112     except py4j.protocol.Py4JJavaError as e:
    113         converted = convert_exception(e.java_exception)

File ~/.pyenv/versions/3.8.10/envs/rikai-ocr/lib/python3.8/site-packages/py4j/protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)
    324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325 if answer[1] == REFERENCE_TYPE:
--> 326     raise Py4JJavaError(
    327         "An error occurred while calling {0}{1}{2}.\n".
    328         format(target_id, ".", name), value)
    329 else:
    330     raise Py4JError(
    331         "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
    332         format(target_id, ".", name, value))

Py4JJavaError: An error occurred while calling o510.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0) (192.168.31.58 executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)
    at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
    at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:773)
    at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:213)
    at net.razorvine.pickle.Unpickler.load(Unpickler.java:123)
    at net.razorvine.pickle.Unpickler.loads(Unpickler.java:136)
    at org.apache.spark.sql.execution.python.BatchEvalPythonExec.$anonfun$evaluate$6(BatchEvalPythonExec.scala:94)
    at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:484)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:490)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1209)
    at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:1215)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:307)
    at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.writeIteratorToStream(PythonUDFRunner.scala:53)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:397)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1996)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:232)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2258)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2207)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2206)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2206)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1079)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1079)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1079)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2445)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2387)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2376)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2196)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2217)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2236)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2261)
    at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1030)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:1029)
    at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:390)
    at org.apache.spark.sql.Dataset.$anonfun$collectToPython$1(Dataset.scala:3519)
    at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3687)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3685)
    at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:3516)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)
    at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
    at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:773)
    at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:213)
    at net.razorvine.pickle.Unpickler.load(Unpickler.java:123)
    at net.razorvine.pickle.Unpickler.loads(Unpickler.java:136)
    at org.apache.spark.sql.execution.python.BatchEvalPythonExec.$anonfun$evaluate$6(BatchEvalPythonExec.scala:94)
    at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:484)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:490)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1209)
    at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:1215)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:307)
    at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.writeIteratorToStream(PythonUDFRunner.scala:53)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:397)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1996)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:232)

da-liii commented 2 years ago

Here is how to reproduce the same error in plain PySpark (see https://issues.apache.org/jira/browse/SPARK-12157):

import pyspark.sql.types as T
import pyspark.sql.functions as F
from pyspark.sql import Row, SparkSession
import numpy as np

spark = SparkSession.builder.getOrCreate()

# np.argmax returns a numpy integer (numpy.int64), which the JVM-side
# unpickler (Pyrolite) cannot construct, so collecting the result raises
# the same "expected zero arguments for construction of ClassDict" error.
argmax = F.udf(lambda x: np.argmax(x), T.IntegerType())

df = spark.createDataFrame([Row(array=[1, 2, 3])])
df.select(argmax("array")).show()
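
For reference, the same query succeeds once the return value is cast to a plain Python int before it leaves the UDF, as suggested on the linked Stack Overflow question:

argmax_ok = F.udf(lambda x: int(np.argmax(x)), T.IntegerType())
df.select(argmax_ok("array")).show()
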
da-liii commented 2 years ago

Well, it turns out that I made a mistake in rikai-ocr; it's fixed now.

See https://github.com/da-tubi/rikai-ocr/pull/2/files

The fix is simply to convert the unexpected numpy float32 values to plain Python float.
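
A minimal sketch of that kind of cast (the helper name is illustrative; see the linked diff for the actual change):

def to_python_coords(box):
    # box: numpy float32 array of corner points; cast every scalar to a
    # plain Python float so Spark's pickler never sees a numpy dtype.
    return [[float(x), float(y)] for x, y in box]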