databricks-demos / dbdemos

Demos to implement your Databricks Lakehouse

Make batch summarizer pipeline in llm-dolly-chatbot run on GPU #26

Closed tlortz closed 1 year ago

tlortz commented 1 year ago

Existing code in the 02-Data-preparation notebook of the llm-dolly-chatbot demo has two issues:

```python
@pandas_udf("string")
def summarize(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Load the model for summarization
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
    def summarize_txt(text):
      if len(text) > 400:
        return summarizer(text)
      return text

    for serie in iterator:
        # get a summary for each row
        yield serie.apply(summarize_txt)
```

1. It doesn't utilize the GPU.
2. The summarization pipeline returns values of type `List[Dict]`, which causes the write operation to fail.
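To illustrate the second issue: a Hugging Face summarization pipeline returns a list of dicts shaped like `[{'summary_text': ...}]`, one per input, so the UDF ends up yielding lists instead of strings. A minimal sketch with a stubbed pipeline (the stub only mimics the return shape of the real `transformers` call):

```python
# Stub that mimics the return shape of a transformers summarization
# pipeline: a list with one {'summary_text': ...} dict per input.
def fake_summarizer(text):
    return [{"summary_text": text[:50]}]

article = "some long article text " * 30

raw = fake_summarizer(article)
# `raw` is a list of dicts, not a str -> a "string" pandas_udf chokes on it

fixed = fake_summarizer(article)[0]["summary_text"]
# `fixed` is a plain str, safe to return from the UDF
```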

Recommend replacing it with:

```python
@pandas_udf("string")
def summarize(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Load the model for summarization; device=0 places it on the GPU
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=0)
    def summarize_txt(text):
      if len(text) > 400:
        # extract the summary string from the List[Dict] result
        return summarizer(text)[0]['summary_text']
      return text

    for serie in iterator:
        # get a summary for each row
        yield serie.apply(summarize_txt)
```

This results in GPU utilization of around 40%, which is probably low because we're using a batch size of 1, but it's definitely faster than running without the GPU. The entire job runs in 16 minutes on the g5.4xlarge cluster created by dbdemos.
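One possible way to raise utilization (a sketch only; the chunk size is an arbitrary assumption, and `summarize_series` is a hypothetical helper) is to call the pipeline once per chunk of rows instead of once per row, since `transformers` pipelines accept a list of inputs. The pipeline is passed in as a parameter here so the batching logic stays independent of the model:

```python
import pandas as pd

def summarize_series(serie: pd.Series, summarizer, chunk_size: int = 8) -> pd.Series:
    """Summarize long texts in chunks so the model sees several inputs per call.

    `summarizer` is expected to behave like a transformers summarization
    pipeline: given a list of strings, it returns a list of
    {'summary_text': ...} dicts, one per input.
    """
    texts = serie.tolist()
    out = list(texts)  # short texts pass through unchanged
    # indices of rows that actually need summarization
    long_idx = [i for i, t in enumerate(texts) if len(t) > 400]
    for start in range(0, len(long_idx), chunk_size):
        batch_idx = long_idx[start:start + chunk_size]
        results = summarizer([texts[i] for i in batch_idx])
        for i, res in zip(batch_idx, results):
            out[i] = res["summary_text"]
    return pd.Series(out, index=serie.index)
```

Inside the UDF, this would replace `serie.apply(summarize_txt)`; the pipeline would still be loaded once per executor, optionally with its `batch_size` argument set to match the chunk size.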

QuentinAmbard commented 1 year ago

Hi @tlortz, thanks! That was fixed ~yesterday in the latest release. Can you have a look here and let us know if this works better? https://www.dbdemos.ai/demo-notebooks.html?demoName=llm-dolly-chatbot

Also, I tried sending a list of texts to the summarizer but didn't see any performance improvement. Let me know if you have tips to increase GPU utilization!

QuentinAmbard commented 1 year ago

Closing this issue as the fix was added in the latest release. Let me know if we can improve it further!