JohnSnowLabs / spark-nlp

State of the Art Natural Language Processing
https://sparknlp.org/
Apache License 2.0

java.lang.IllegalArgumentException: No Operation named [missing_decoder_input_ids_init] in the Graph when using pretrained model #14343

Closed hortiprajwal closed 1 month ago

hortiprajwal commented 2 months ago

Is there an existing issue for this?

Who can help?

No response

What are you working on?

I am trying to run Spark NLP in Databricks to summarize text. I am using the bart_large_cnn model for the summarization.

Current Behavior

Py4JJavaError: An error occurred while calling o584.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 9.0 failed 4 times, most recent failure: Lost task 2.3 in stage 9.0 (TID 412) (172.16.2.60 executor 0): java.lang.IllegalArgumentException: No Operation named [missing_decoder_input_ids_init] in the Graph
    at org.tensorflow.Graph.outputOrThrow(Graph.java:211)
    at org.tensorflow.Session$Runner.feed(Session.java:248)
    at com.johnsnowlabs.ml.ai.Bart.getModelOutput(Bart.scala:414)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.$anonfun$beamSearch$7(Generate.scala:225)
    at scala.util.control.Breaks.breakable(Breaks.scala:42)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch(Generate.scala:213)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch$(Generate.scala:182)
    at com.johnsnowlabs.ml.ai.Bart.beamSearch(Bart.scala:40)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate(Generate.scala:151)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate$(Generate.scala:85)
    at com.johnsnowlabs.ml.ai.Bart.generate(Bart.scala:40)
    at com.johnsnowlabs.ml.ai.Bart.tag(Bart.scala:280)
    at com.johnsnowlabs.ml.ai.Bart.$anonfun$predict$1(Bart.scala:124)
    at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
    at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
    at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
    at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
    at com.johnsnowlabs.ml.ai.Bart.predict(Bart.scala:109)
    at com.johnsnowlabs.nlp.annotators.seq2seq.BartTransformer.batchAnnotate(BartTransformer.scala:324)
    at com.johnsnowlabs.nlp.HasBatchedAnnotate.processBatchRows(HasBatchedAnnotate.scala:65)
    at com.johnsnowlabs.nlp.HasBatchedAnnotate.$anonfun$batchProcess$1(HasBatchedAnnotate.scala:53)
    at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$5(UnsafeRowBatchUtils.scala:88)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$3(UnsafeRowBatchUtils.scala:88)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$1(UnsafeRowBatchUtils.scala:68)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:62)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$2(Collector.scala:197)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:82)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:82)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:62)
    at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:196)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:181)
    at org.apache.spark.scheduler.Task.$anonfun$run$5(Task.scala:146)
    at com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:45)
    at com.databricks.unity.HandleImpl.runWith(UCSHandle.scala:103)
    at com.databricks.unity.HandleImpl.$anonfun$runWithAndClose$1(UCSHandle.scala:108)
    at scala.util.Using$.resource(Using.scala:269)
    at com.databricks.unity.HandleImpl.runWithAndClose(UCSHandle.scala:107)
    at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:146)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.Task.run(Task.scala:99)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$8(Executor.scala:900)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1709)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:903)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:798)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:750)

Expected Behavior

The code should print the summary and run without errors.

Steps To Reproduce

from abc import ABCMeta, abstractmethod

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, length, lit, isnull, when
from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import BartTransformer

data = spark.createDataFrame(
    [
        (
            """Landing on another planetary body has been one of the biggest challenges and engineers across the world have been working to develop innovating technology to ensure a smooth touchdown, be it the Moon or Mars.

About 27 years ago, as a Nasa-led spacecraft was hurtling down through the thin Martian atmosphere, all eyes were glaring on the deep space network, which was looking for signs of success. Onboard was a revolutionary method that was all about a bouncy landing.

In July 1997, the spacecraft made history by successfully bouncing 15 times on the Martian surface before coming to rest, demonstrating a new and effective method for landing on Mars.

The Pathfinder mission, consisting of a lander and the small Sojourner rover, employed a groundbreaking approach to cushion its impact.

Pathfinder was enveloped in a cluster of airbags, allowing it to bounce safely to a stop on the planet's surface.

This airbag-mediated landing technique proved to be a game-changer for Mars exploration. It offered a more cost-effective and lightweight alternative to the previous rocket-powered soft landing methods used by earlier missions like Viking.

The success of Pathfinder's landing system paved the way for future Mars missions, including the Mars Exploration Rovers Spirit and Opportunity, which also utilized airbag technology for their landings in 2004.

The bouncing landing not only protected the delicate scientific instruments on board but also allowed for a [wider range of potential landing sites](https://www.indiatoday.in/science/story/mars-math-and-mohan-meet-the-indian-scientist-who-landed-rover-on-red-planet-2512370-2024-03-08). This flexibility was crucial for expanding the areas of Mars that could be explored by robotic missions.

Pathfinder's success demonstrated the viability of NASA's "faster, better, cheaper" approach to planetary exploration. T[he mission's innovative landing technique](https://www.indiatoday.in/science/story/mangalayaan-2-isro-mars-mission-sky-crane-helicopter-rover-2538652-2024-05-13), combined with its scientific achievements, helped reinvigorate public interest in Mars exploration and set the stage for future missions to the Red Planet.

The legacy of Pathfinder's bouncing landing continues to influence Mars exploration strategies, inspiring engineers to develop new and creative solutions for the challenges of interplanetary travel and exploration.""",
        )
    ]
).toDF("content")

class SparkNLPModel(metaclass=ABCMeta):
    @staticmethod
    @abstractmethod
    def build_pipeline(input_col: str) -> Pipeline:
        pass

    @abstractmethod
    def process(self, data: DataFrame, input_col: str = "content") -> DataFrame:
        pass

class SummarizationModel(SparkNLPModel):
    @staticmethod
    def build_pipeline(input_col: str) -> Pipeline:
        document_assembler = DocumentAssembler().setInputCol(input_col).setOutputCol("document")

        bart_transformer = (
            BartTransformer.pretrained(name="bart_large_cnn", lang="en")
            .setTask("summarize:")
            .setInputCols(["document"])
            .setMaxOutputLength(200)
            .setOutputCol("summaries")
        )

        pipeline = Pipeline().setStages([document_assembler, bart_transformer])
        return pipeline

    def process(self, data: DataFrame, input_col: str) -> DataFrame:
        data = data.withColumn("content_length", when(isnull(col(input_col)), 0).otherwise(length(col(input_col))))

        # Generate summaries only for rows whose content exceeds 200 characters;
        # shorter rows get an empty summary.
        long_content_df = data.filter(col("content_length") > 200)
        short_content_df = data.filter(col("content_length") <= 200)
        short_content_df = short_content_df.withColumn("summary", lit(""))

        if long_content_df.count() > 0:
            pipeline = self.build_pipeline(input_col=input_col)
            summarized_df = pipeline.fit(data).transform(long_content_df)
            summarized_df = summarized_df.withColumn("summary", col("summaries.result")[0])
            summarized_df = summarized_df.drop("document", "summaries")
            data = summarized_df.union(short_content_df)
        else:
            data = short_content_df

        return data.drop("content_length")

news_articles_df = SummarizationModel().process(data=data, input_col="content")
news_articles_df.select('summary').show(truncate=False)
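
A note on where the failure surfaces (my observation, not part of the original report): Spark evaluates the pipeline lazily, so the o584.showString frame in the trace is just the final .show() action triggering execution. Any other action reproduces the same exception, which confirms it originates in the BartTransformer stage rather than in .show() itself:

# Hypothetical alternative trigger: any action (collect, count, show) forces
# the BART stage to run and raises the same IllegalArgumentException.
news_articles_df.select("summary").limit(1).collect()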

Spark NLP version and Apache Spark

Spark NLP version: 5.4.0
Spark version: 3.4.1

Databricks Runtime Version: 13.3 LTS (includes Apache Spark 3.4.1, Scala 2.12)
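
For completeness, these versions can be confirmed from a notebook cell (a small sketch using the standard version checks):

import sparknlp

print(sparknlp.version())  # expected: 5.4.0
print(spark.version)       # expected: 3.4.1 (Databricks Runtime 13.3 LTS)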

Type of Spark Application

Python Application

Java Version

No response

Java Home Directory

No response

Setup and installation

Installed the dependencies from Maven and PyPI under Libraries in Databricks:

com.johnsnowlabs.nlp:spark-nlp_2.12:5.4.0 
spark-nlp==5.4.0
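
For a local (non-Databricks) reproduction, a sketch assuming the same versions: with spark-nlp==5.4.0 installed from PyPI, sparknlp.start() resolves the matching Maven package automatically.

import sparknlp

# Starts a local SparkSession and pulls
# com.johnsnowlabs.nlp:spark-nlp_2.12:5.4.0 from Maven automatically.
spark = sparknlp.start()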

These Spark configs are set on the cluster:

spark.serializer org.apache.spark.serializer.KryoSerializer
spark.kryoserializer.buffer.max 2000M
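
For reference, the equivalent configuration on a self-managed SparkSession would look like this (a sketch; on Databricks these values go in the cluster's Spark config instead):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("spark-nlp-bart-summarization")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryoserializer.buffer.max", "2000M")
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.4.0")
    .getOrCreate()
)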

Operating System and Version

No response

Link to your project (if available)

No response

Additional Information

ahmedlone127 commented 1 month ago

Hello @hortiprajwal, I have re-uploaded this model and your issue should be resolved with the latest version of Spark NLP (5.4.1).

hortiprajwal commented 1 month ago

> Hello @hortiprajwal, I have re-uploaded this model and your issue should be resolved with the latest version of Spark NLP (5.4.1).

Hello @ahmedlone127, I tried this with Spark NLP 5.4.1, but the issue still persists.

java.lang.IllegalArgumentException: No Operation named [missing_decoder_input_ids_init] in the Graph
    at org.tensorflow.Graph.outputOrThrow(Graph.java:211)
    at org.tensorflow.Session$Runner.feed(Session.java:248)
    at com.johnsnowlabs.ml.ai.Bart.getModelOutput(Bart.scala:414)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.$anonfun$beamSearch$7(Generate.scala:228)
    at scala.util.control.Breaks.breakable(Breaks.scala:42)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch(Generate.scala:216)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch$(Generate.scala:184)
    at com.johnsnowlabs.ml.ai.Bart.beamSearch(Bart.scala:40)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate(Generate.scala:153)
    at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate$(Generate.scala:85)
    at com.johnsnowlabs.ml.ai.Bart.generate(Bart.scala:40)
    at com.johnsnowlabs.ml.ai.Bart.tag(Bart.scala:280)
    at com.johnsnowlabs.ml.ai.Bart.$anonfun$predict$1(Bart.scala:124)
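
One thing that may be worth ruling out here (an assumption on my part, not a confirmed fix): since the model was re-uploaded, a copy cached before the re-upload would still contain the old graph. Spark NLP caches pretrained models under ~/cache_pretrained by default (configurable via spark.jsl.settings.pretrained.cache_folder), so clearing the cached bart_large_cnn folder before calling .pretrained() again forces a fresh download:

import glob
import os
import shutil

# Hypothetical cache reset: remove locally cached copies of the model so
# BartTransformer.pretrained() re-downloads the re-uploaded graph. The folder
# name pattern is an assumption based on the default cache layout.
for path in glob.glob(os.path.expanduser("~/cache_pretrained/bart_large_cnn_en_*")):
    shutil.rmtree(path)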