JohnSnowLabs / spark-nlp

State of the Art Natural Language Processing
https://sparknlp.org/
Apache License 2.0

SpanBertCorefModel failed to execute udf(HasSimpleAnnotate) #10727

Open jenghub opened 2 years ago

jenghub commented 2 years ago

I am attempting to perform coreference resolution on a large DataFrame of English-language documents.

Description

Running SpanBertCorefModel in the following pipeline:

from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import SentenceDetectorDLModel, Tokenizer, SpanBertCorefModel

# `df` is the input DataFrame; its text column is "extracted_doc".
document_assembler = DocumentAssembler().setInputCol("extracted_doc").setOutputCol("document")
sentence_detector = SentenceDetectorDLModel.pretrained("sentence_detector_dl", "en").setInputCols(["document"]).setOutputCol("sentences")
tokenizer = Tokenizer().setInputCols(["sentences"]).setOutputCol("tokens")
corefResolution = SpanBertCorefModel.pretrained("spanbert_base_coref").setInputCols(["sentences", "tokens"]).setOutputCol("corefs")
# coref_pipeline = Pipeline(stages=[document_assembler, sentence_detector, tokenizer]) # works fine
coref_pipeline = Pipeline(stages=[document_assembler, sentence_detector, tokenizer, corefResolution])

coref_model = coref_pipeline.fit(df)
test_result = coref_model.transform(df)
display(test_result)

The block above downloads the models correctly but then fails at the coreference resolution step:

sentence_detector_dl download started this may take some time.
Approximate size to download 354.6 KB
[OK!]
spanbert_base_coref download started this may take some time.
Approximate size to download 540.1 MB
[OK!]

Error:

~/cluster-env/clonedenv/lib/python3.8/site-packages/notebookutils/visualization/display.py in display(data, summary)
    138         log4jLogger\
    139             .error(f"display failed with error, language: python, error: {err}")
--> 140         raise err
    141 
    142     log4jLogger\

~/cluster-env/clonedenv/lib/python3.8/site-packages/notebookutils/visualization/display.py in display(data, summary)
    118                     from IPython.display import publish_display_data
    119                     publish_display_data({
--> 120                         "application/vnd.synapse.display-widget+json": sc._jvm.display.getDisplayResultForIPython(df._jdf, summary)
    121                     })
    122                 else:

~/cluster-env/clonedenv/lib/python3.8/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1302 
   1303         answer = self.gateway_client.send_command(command)
-> 1304         return_value = get_return_value(
   1305             answer, self.gateway_client, self.target_id, self.name)
   1306 

/opt/spark/python/lib/pyspark.zip/pyspark/sql/utils.py in deco(*a, **kw)
    109     def deco(*a, **kw):
    110         try:
--> 111             return f(*a, **kw)
    112         except py4j.protocol.Py4JJavaError as e:
    113             converted = convert_exception(e.java_exception)

~/cluster-env/clonedenv/lib/python3.8/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    324             value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325             if answer[1] == REFERENCE_TYPE:
--> 326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
    328                     format(target_id, ".", name), value)

Py4JJavaError: An error occurred while calling z:com.microsoft.spark.notebook.visualization.display.getDisplayResultForIPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 39.0 failed 4 times, most recent failure: Lost task 0.3 in stage 39.0 (TID 148) (vm-e5099158 executor 1): org.apache.spark.SparkException: Failed to execute user defined function(HasSimpleAnnotate$$Lambda$6097/176097372: (array<array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>>) => array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:762)
    at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:383)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:905)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:905)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:57)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:131)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IndexOutOfBoundsException: 1
    at scala.collection.mutable.ResizableArray.apply(ResizableArray.scala:46)
    at scala.collection.mutable.ResizableArray.apply$(ResizableArray.scala:45)
    at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:49)
    at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1(TensorflowSpanBertCoref.scala:52)
    at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1$adapted(TensorflowSpanBertCoref.scala:34)
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
    at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.predict(TensorflowSpanBertCoref.scala:34)
    at com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel.annotate(SpanBertCorefModel.scala:327)
    at com.johnsnowlabs.nlp.HasSimpleAnnotate.$anonfun$dfAnnotate$1(HasSimpleAnnotate.scala:46)
    ... 17 more

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2313)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2262)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2261)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2261)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1132)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1132)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1132)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2500)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2442)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2431)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:908)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2301)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2322)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2341)
    at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:510)
    at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:463)
    at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:47)
    at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3709)
    at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2735)
    at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3700)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:107)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:181)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:94)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:68)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3698)
    at org.apache.spark.sql.Dataset.head(Dataset.scala:2735)
    at org.apache.spark.sql.Dataset.take(Dataset.scala:2942)
    at org.apache.spark.sql.GetRowsHelper$.getRowsInJsonString(GetRowsHelper.scala:51)
    at com.microsoft.spark.notebook.visualization.display$.generateTableConfig(Display.scala:454)
    at com.microsoft.spark.notebook.visualization.display$.exec(Display.scala:189)
    at com.microsoft.spark.notebook.visualization.display$.getDisplayResultInternal(Display.scala:139)
    at com.microsoft.spark.notebook.visualization.display$.getDisplayResultForIPython(Display.scala:80)
    at com.microsoft.spark.notebook.visualization.display.getDisplayResultForIPython(Display.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function(HasSimpleAnnotate$$Lambda$6097/176097372: (array<array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>>) => array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:762)
    at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:383)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:905)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:905)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:57)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:131)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more
Caused by: java.lang.IndexOutOfBoundsException: 1
    at scala.collection.mutable.ResizableArray.apply(ResizableArray.scala:46)
    at scala.collection.mutable.ResizableArray.apply$(ResizableArray.scala:45)
    at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:49)
    at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1(TensorflowSpanBertCoref.scala:52)
    at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1$adapted(TensorflowSpanBertCoref.scala:34)
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
    at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.predict(TensorflowSpanBertCoref.scala:34)
    at com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel.annotate(SpanBertCorefModel.scala:327)
    at com.johnsnowlabs.nlp.HasSimpleAnnotate.$anonfun$dfAnnotate$1(HasSimpleAnnotate.scala:46)
    ... 17 more

Possible Solution

The demo code below works properly. Could the problem lie in the tokens or sentences being detected? Some of the documents may contain artifacts, because an earlier preprocessing step may have translated them into English.

data = spark.createDataFrame([["John told Mary he would like to borrow a book from her."]]).toDF("text")
document_assembler = DocumentAssembler().setInputCol("text").setOutputCol("document")
sentence_detector = SentenceDetector().setInputCols(["document"]).setOutputCol("sentences")
tokenizer = Tokenizer().setInputCols(["sentences"]).setOutputCol("tokens")
corefResolution = SpanBertCorefModel.pretrained("spanbert_base_coref").setInputCols(["sentences", "tokens"]).setOutputCol("corefs")
pipeline = Pipeline(stages=[document_assembler, sentence_detector, tokenizer, corefResolution])

model = pipeline.fit(data)

model.transform(data).selectExpr("explode(corefs) AS coref").selectExpr("coref.result as token", "coref.metadata").show(truncate=False)
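
One way to check whether particular documents trigger the failure is to run inputs one at a time and record which ones raise. The following is only a minimal sketch and not part of Spark NLP: `find_failing_inputs` and the `annotate` callable are hypothetical names, and `annotate` stands in for whatever runs the fitted pipeline on a single string.

```python
from typing import Callable, Iterable, List, Tuple


def find_failing_inputs(
    texts: Iterable[str],
    annotate: Callable[[str], object],
) -> List[Tuple[int, str, str]]:
    """Run `annotate` on each text and collect the ones that raise.

    Returns a list of (index, text, error message) tuples.
    `annotate` is a hypothetical stand-in for a single-document pipeline call.
    """
    failures = []
    for index, text in enumerate(texts):
        try:
            annotate(text)
        except Exception as err:  # broad on purpose: record the error and continue
            failures.append((index, text, f"{type(err).__name__}: {err}"))
    return failures


if __name__ == "__main__":
    # Stand-in annotator that fails on blank input, for demonstration only.
    def fake_annotate(text: str) -> int:
        if not text.strip():
            raise IndexError("1")
        return len(text.split())

    sample = ["John told Mary he would like to borrow a book from her.", "   "]
    for index, text, error in find_failing_inputs(sample, fake_annotate):
        print(index, repr(text), error)
```

In a Spark NLP session, `annotate` might wrap something like `LightPipeline(model).annotate`, assuming the error also reproduces outside the DataFrame path; that has not been verified here.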

Context

I am trying to use Spark-based coreference resolution because other libraries are either incompatible with my Spark environment or are slow and error out.

Your Environment

maziyarpanahi commented 2 years ago

Thanks for reporting this issue @jenghub. This feature was released only recently, and I will ask @vankov to have a look at it.

teowz46 commented 1 year ago

Any updates on this? I'm encountering essentially the same error, except that where OP's error mentions z:com.microsoft.spark.notebook.visualization.display.getDisplayResultForIPython, mine mentions oXXXX.showString (where XXXX is some number). I noticed that the error starts occurring when the input has more than around 120 words.

sparknlp.version(): 5.1.4
spark.version: 3.5.0
java -version: 1.8.0_392
Installed via pip install
OS: Ubuntu 20.04.5 LTS
Other details: this is being done offline; spark.jars was loaded into the SparkSession from a local file, and the model was loaded from a copy downloaded from Models Hub.
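
The length threshold mentioned above could be pinned down by testing progressively longer word prefixes of a failing document. The sketch below is a hypothetical helper, not a Spark NLP API: `fails` is an assumed callable that returns True when the pipeline raises on the given text, and the binary search assumes that once a prefix fails, every longer prefix fails too.

```python
from typing import Callable, Optional


def smallest_failing_word_count(
    text: str,
    fails: Callable[[str], bool],
) -> Optional[int]:
    """Return the smallest number of leading words for which `fails` is True.

    Returns None if the full text does not fail. Assumes failure is monotonic
    in prefix length (an assumption, not something verified here). Words are
    split on whitespace and rejoined with single spaces, so the original
    spacing is not preserved.
    """
    words = text.split()
    if not words or not fails(" ".join(words)):
        return None
    low, high = 1, len(words)  # invariant: the prefix of length `high` fails
    while low < high:
        mid = (low + high) // 2
        if fails(" ".join(words[:mid])):
            high = mid  # failure already present at `mid` words
        else:
            low = mid + 1  # need more than `mid` words to fail
    return low


if __name__ == "__main__":
    # Stand-in predicate that "fails" above 120 words, for demonstration only.
    document = " ".join(["word"] * 200)
    print(smallest_failing_word_count(document, lambda t: len(t.split()) > 120))
```

If the reported count stays stable across several failing documents, that would point to a length limit rather than to specific tokens; if it varies widely, the content of the text is the more likely cause.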

khaibenz commented 7 months ago

I am facing the same issue, any updates on this?