JohnSnowLabs / spark-nlp

State of the Art Natural Language Processing
https://sparknlp.org/
Apache License 2.0

BART Summarization max tokens? #13829

Closed clabornd closed 1 year ago

clabornd commented 1 year ago

Is there an existing issue for this?

What are you working on?

I am trying to summarize potentially long texts with distilbart_xsum_12_6.

Current Behavior

Currently I get an error on long texts:

TFInvalidArgumentException: {{function_node __inference_pruned_56247}} {{function_node __inference_pruned_56247}} indices[1054] = 1056 is not in [0, 1026)

(full error at the bottom)

Expected Behavior

Maybe not expected behavior, but I always assumed something was going on under the hood to handle texts longer than a particular model's max context length. Does no such mechanism exist for BART, or in general?
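For reference, a common user-side way to handle inputs longer than a model's context window is to split them into windows, summarize each window as its own row, and then join the partial summaries. The sketch below is not part of Spark NLP. `chunk_words` and its defaults are hypothetical, and it counts whitespace-separated words, which only approximates BART's subword tokens, so the 700-word window is a deliberately conservative assumption rather than a documented limit.

```python
from typing import List

def chunk_words(text: str, max_words: int = 700, overlap: int = 50) -> List[str]:
    """Split text into overlapping windows of at most max_words words.

    Hypothetical helper: word counts only approximate BART subword tokens,
    so max_words is kept well below the 1024-token context length.
    """
    if max_words <= 0 or not 0 <= overlap < max_words:
        raise ValueError("need max_words > 0 and 0 <= overlap < max_words")
    words = text.split()
    step = max_words - overlap  # how far each new window starts after the previous one
    chunks: List[str] = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + max_words]))
        if start + max_words >= len(words):  # this window already reaches the end
            break
    return chunks
```

Each chunk would then go into its own row of the input DataFrame, and the per-chunk summaries can be concatenated or summarized again. The overlap keeps a sentence that straddles a window boundary visible to both windows.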

Steps To Reproduce

A rough example:

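from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline

# `spark` is predefined on Databricks; outside Databricks, create a session first:
# import sparknlp; spark = sparknlp.start()

documentAssembler = DocumentAssembler() \
    .setInputCol('text') \
    .setOutputCol('document')

bart = BartTransformer.pretrained("distilbart_xsum_12_6") \
    .setTask("summarize:") \
    .setMaxOutputLength(250) \
    .setInputCols(["document"]) \
    .setOutputCol("summary")

nlp_pipeline = Pipeline(stages=[
    documentAssembler,
    bart
])

# Both stages are pretrained, so fitting on an empty DataFrame is enough
pipeline_model = nlp_pipeline.fit(spark.createDataFrame([['']]).toDF('text'))

# Three rows, each with 1027 repetitions of "word" (longer than BART's 1024-token context)
sentences = [
    [" ".join(["word"] * 1027)]
    for _ in range(3)
]

data = spark.createDataFrame(sentences).toDF("text")
data.show()

# display() is Databricks' built-in and triggers the computation;
# elsewhere use an action such as pipeline_model.transform(data).show()
display(pipeline_model.transform(data))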

Spark NLP version and Apache Spark

Spark NLP 4.4.0, Apache Spark 3.3.2

Full error:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 104.0 failed 4 times, most recent failure: Lost task 1.3 in stage 104.0 (TID 1844) (10.139.64.7 executor 15): org.tensorflow.exceptions.TFInvalidArgumentException: {{function_node __inference_pruned_56247}} {{function_node __inference_pruned_56247}} indices[1054] = 1056 is not in [0, 1026)
     [[{{node encoder/embed_positions/embedding_lookup}}]]
     [[StatefulPartitionedCall_1/StatefulPartitionedCall/StatefulPartitionedCall]]
    at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:87)
    at org.tensorflow.Session.run(Session.java:850)
    at org.tensorflow.Session.access$300(Session.java:82)
    at org.tensorflow.Session$Runner.runHelper(Session.java:552)
    at org.tensorflow.Session$Runner.runNoInit(Session.java:499)
    at org.tensorflow.Session$Runner.run(Session.java:495)
    at com.johnsnowlabs.ml.ai.Bart.tag(Bart.scala:169)
    at com.johnsnowlabs.ml.ai.Bart.$anonfun$predict$1(Bart.scala:763)
    at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
    at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
    at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
    at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
    at com.johnsnowlabs.ml.ai.Bart.predict(Bart.scala:749)
    at com.johnsnowlabs.nlp.annotators.seq2seq.BartTransformer.batchAnnotate(BartTransformer.scala:487)
    at com.johnsnowlabs.nlp.HasBatchedAnnotate.$anonfun$batchProcess$1(HasBatchedAnnotate.scala:59)
    at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage4.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:761)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:82)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:208)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:174)
    at org.apache.spark.scheduler.Task.$anonfun$run$5(Task.scala:142)
    at com.databricks.unity.EmptyHandle$.runWithAndClose(UCSHandle.scala:125)
    at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:142)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.Task.run(Task.scala:97)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:904)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1713)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:907)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:761)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:750)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3377)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3309)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3300)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3300)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1429)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1429)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1429)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3589)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3527)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3515)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:51)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1178)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1166)
    at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2737)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$runSparkJobs$1(Collector.scala:349)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
    at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:293)
    at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:377)
    at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:128)
    at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:135)
    at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:123)
    at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:111)
    at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:93)
    at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$1(ResultCacheManager.scala:537)
    at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
    at org.apache.spark.sql.execution.qrc.ResultCacheManager.collectResult$1(ResultCacheManager.scala:529)
    at org.apache.spark.sql.execution.qrc.ResultCacheManager.computeResult(ResultCacheManager.scala:549)
    at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$getOrComputeResultInternal$1(ResultCacheManager.scala:402)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.execution.qrc.ResultCacheManager.getOrComputeResultInternal(ResultCacheManager.scala:395)
    at org.apache.spark.sql.execution.qrc.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:289)
    at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeCollectResult$1(SparkPlan.scala:506)
    at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
    at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:503)
    at org.apache.spark.sql.Dataset.collectResult(Dataset.scala:3453)
    at org.apache.spark.sql.Dataset.$anonfun$collectResult$1(Dataset.scala:3444)
    at org.apache.spark.sql.Dataset.$anonfun$withAction$3(Dataset.scala:4368)
    at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:809)
    at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4366)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withCustomExecutionEnv$8(SQLExecution.scala:227)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:410)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withCustomExecutionEnv$1(SQLExecution.scala:172)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1038)
    at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:122)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:360)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4366)
    at org.apache.spark.sql.Dataset.collectResult(Dataset.scala:3443)
    at com.databricks.backend.daemon.driver.OutputAggregator$.withOutputAggregation0(OutputAggregator.scala:267)
    at com.databricks.backend.daemon.driver.OutputAggregator$.withOutputAggregation(OutputAggregator.scala:101)
    at com.databricks.backend.daemon.driver.PythonDriverLocalBase.generateTableResult(PythonDriverLocalBase.scala:723)
    at com.databricks.backend.daemon.driver.JupyterDriverLocal.computeListResultsItem(JupyterDriverLocal.scala:839)
    at com.databricks.backend.daemon.driver.JupyterDriverLocal$JupyterEntryPoint.addCustomDisplayData(JupyterDriverLocal.scala:258)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
    at py4j.Gateway.invoke(Gateway.java:306)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:195)
    at py4j.ClientServerConnection.run(ClientServerConnection.java:115)
    at java.lang.Thread.run(Thread.java:750)
Caused by: org.tensorflow.exceptions.TFInvalidArgumentException: {{function_node __inference_pruned_56247}} {{function_node __inference_pruned_56247}} indices[1054] = 1056 is not in [0, 1026)
     [[{{node encoder/embed_positions/embedding_lookup}}]]
     [[StatefulPartitionedCall_1/StatefulPartitionedCall/StatefulPartitionedCall]]
    ... (remaining frames identical to the executor stack trace above)
maziyarpanahi commented 1 year ago

Hi @clabornd

Could you please update Spark NLP to spark-nlp==4.4.3? That release introduced speed and memory optimizations along with some code enhancements and bug fixes:

https://colab.research.google.com/drive/1KucyhiPBc5Eivkiyx94_VFaba8K1bvoV?usp=sharing

clabornd commented 1 year ago

Thanks for the fast response. I upgraded to spark-nlp==4.4.3, but the issue persists. I'm running this on Databricks, if that's relevant. I tested with runtimes 13.0 and 12.2, and also on a single-node machine, with no luck; the error looks the same:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 7.0 failed 4 times, most recent failure: Lost task 0.3 in stage 7.0 (TID 27) (10.139.64.6 executor driver): org.tensorflow.exceptions.TFInvalidArgumentException: indices[1024] = 1026 is not in [0, 1026)
     [[{{function_node __inference_encoder_serving_912071}}{{node encoder/embed_positions/embedding_lookup}}]]
    at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:87)

...
maziyarpanahi commented 1 year ago

You are welcome. I cannot reproduce this issue on any platform. It seems you only updated the PyPI package, which contains just the Python API wrappers; the library's actual logic lives in the Maven package. Could you please follow these instructions and make sure the spark-nlp Maven dependency is also 4.4.3? https://github.com/JohnSnowLabs/spark-nlp#databricks-cluster

If everything is on 4.4.3 and it still doesn't work, you can also share a screenshot of the Libraries tab in your cluster configuration. (Mine is on 4.4.3 and it works.)

clabornd commented 1 year ago

Sorry, I didn't mention that I also updated the Maven package. My Libraries tab looks like this:

[Screenshot: cluster Libraries tab, 2023-05-30]

I also tried uninstalling everything except the Spark NLP PyPI and Maven packages.

Cluster config, in case it's useful:

{
    "autoscale": {
        "min_workers": 1,
        "max_workers": 8
    },
    "cluster_name": "memory-odbc",
    "spark_version": "13.0.x-scala2.12",
    "spark_conf": {
        "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
        "spark.kryoserializer.buffer.max": "2000M",
        "spark.sql.broadcastTimeout": "40000",
        "spark.databricks.delta.preview.enabled": "true"
    },
    "azure_attributes": {
        "first_on_demand": 1,
        "availability": "ON_DEMAND_AZURE",
        "spot_bid_max_price": -1
    },
    "node_type_id": "Standard_DS13_v2",
    "driver_node_type_id": "Standard_DS13_v2",
    "ssh_public_keys": [],
    "custom_tags": {},
    "spark_env_vars": {},
    "autotermination_minutes": 120,
    "enable_elastic_disk": true,
    "cluster_source": "UI",
    "init_scripts": [],
    "enable_local_disk_encryption": false,
    "runtime_engine": "STANDARD",
    "cluster_id": "0526-205224-3pve3hjz"
}
clabornd commented 1 year ago

OK, I'm now seeing that error in the Colab notebook as well. It should appear if you replace display(...) with an action: pipeline_model.transform(data).show(). (Databricks' display() is different from IPython.display.display: on Databricks it runs the query, whereas in Colab display() only prints the DataFrame's representation without triggering any computation.)
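The reason display() in Colab hid the error is Spark's lazy evaluation: transform() only builds a query plan, and the model actually runs when an action such as .show() or .collect() executes. A plain-Python analogy using a lazy map(), where the hypothetical `summarize` function stands in for the BART stage and a word count stands in for the token count:

```python
def summarize(text: str) -> str:
    """Hypothetical stand-in for the BART stage: fails on over-long input."""
    if len(text.split()) > 1024:  # word count used as a rough proxy for tokens
        raise ValueError("input is longer than 1024 tokens")
    return "summary of: " + text[:20]

rows = ["a short document", " ".join(["word"] * 1027)]

# Like DataFrame.transform(): map() only records the work to do,
# so no error is raised on this line.
lazy_summaries = map(summarize, rows)

# Like an action such as .show() or .collect(): consuming the iterator
# forces evaluation, and only now does the over-long row raise.
try:
    summaries = list(lazy_summaries)
except ValueError as err:
    print("error surfaced only on evaluation:", err)
```

This is only an analogy for the evaluation order, not how Spark executes the pipeline internally.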

The text length that triggers this is 1025 tokens, and many BART versions have a max context length of 1024. Is this not just an issue with max context length?
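The numbers in both error messages line up with how the Hugging Face BART implementation stores learned positional embeddings: a table of max_position_embeddings + 2 rows, with every position id shifted by an offset of 2. Assuming the TensorFlow graph behind distilbart_xsum_12_6 keeps that layout (an assumption, not confirmed in this thread), the arithmetic is:

```python
# Assumption: HF-style BART learned positional embeddings with an offset of 2,
# i.e. an embedding table of 1024 + 2 = 1026 rows.
MAX_POSITIONS = 1024
OFFSET = 2
TABLE_ROWS = MAX_POSITIONS + OFFSET  # matches "[0, 1026)" in the error message

def position_index(position: int) -> int:
    """Row looked up in encoder/embed_positions for a 0-based token position."""
    return position + OFFSET

def fits(num_tokens: int) -> bool:
    """True if every position of a num_tokens-long input has an embedding row."""
    return num_tokens == 0 or position_index(num_tokens - 1) < TABLE_ROWS
```

Under that assumption, position 1024 (the 1025th token) looks up row 1026, which is exactly the out-of-range index in "indices[1024] = 1026 is not in [0, 1026)", so 1024 tokens is the most the encoder can accept, and any special tokens the tokenizer adds count toward it.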

maziyarpanahi commented 1 year ago

Re: "The text length that triggers this is 1025 tokens, and many BART versions have a max context length of 1024. Is this not just an issue with max context length?"

I see now. This is actually a bug: BART's limit is 1024 tokens, and since there is no setMaxInputLength parameter that would throw an error for users (as we do with BERT), we should truncate anything longer than 1024 tokens internally.

I thought we were already doing that internally. It will be fixed in the next release.
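A minimal sketch of the kind of internal truncation described above, applied to the encoded token ids before they reach the encoder. This is not the fix that shipped in Spark NLP; `truncate_input_ids` and the constants are hypothetical, and the special-token ids (<s> = 0, </s> = 2) are assumed from the Hugging Face BART vocabulary and should be checked against the actual model.

```python
from typing import List

# Assumed special-token ids from the Hugging Face BART vocabulary.
BOS_ID = 0  # <s>
EOS_ID = 2  # </s>
MAX_INPUT_LENGTH = 1024  # BART's positional-embedding limit discussed above

def truncate_input_ids(ids: List[int], max_len: int = MAX_INPUT_LENGTH,
                       eos_id: int = EOS_ID) -> List[int]:
    """Cut an encoded sequence to max_len tokens, keeping </s> as the last token."""
    if max_len < 1:
        raise ValueError("max_len must be at least 1")
    if len(ids) <= max_len:
        return list(ids)  # already short enough: leave untouched
    # Drop the overflow, then re-append the end-of-sequence marker
    return list(ids[:max_len - 1]) + [eos_id]
```

Keeping </s> at the end preserves the sequence format the encoder saw during training, instead of simply chopping the input mid-sequence.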