iterative / datachain

DataChain 🔗 AI-dataframe to enrich, transform and analyze data from cloud storages for ML training and LLM apps
https://datachain.dvc.ai
Apache License 2.0

Aggregator and generator inputs should be Iterators, not lists #173

Open volkfox opened 1 month ago

volkfox commented 1 month ago

Description

Here is a sample generator from the LLM tutorial:

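from typing import Iterator, List

from datachain import DataChain
from pydantic import BaseModel

class Dialog(BaseModel):
    id: int
    text: str

def text_block(id: List[int], sender: List[str], text: List[str]) -> Iterator[Dialog]:
    # Every input column arrives as a list covering the whole partition.
    columns = zip(text, sender)
    conversation = ""
    for message, speaker in columns:
        conversation = "\n ".join([conversation, f"{speaker}: {message}"])
    # id is also a list, even though all of its values are equal within a partition.
    yield Dialog(id=id[0], text=conversation)

chain = (
    DataChain.from_csv("gs://datachain-demo/chatbot-csv/")
    .agg(text_block, output={"dialog": Dialog}, partition_by="id")
    .save()
)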

This syntax has a number of issues:

  1. Input column names are implicitly used as the argument names, and each argument is a list. This is awkward because the argument "sender" holds a list of senders and would be better named "senders".

  2. Passing lists from SQL makes out-of-memory processing hard, because each partition has to be fully materialized as lists before the function sees it.

  3. The aggregation key does not need to be passed as a list, because its value is identical in every record of the partition.
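To make issues 1 to 3 concrete without a DataChain install, here is a plain-Python sketch of what list-based aggregation hands to the function. The driver run_list_agg, the ROWS sample, and the simplified text_block (it joins lines with "\n" and has no leading separator, unlike the tutorial's "\n " join) are hypothetical stand-ins, not DataChain internals.

```python
from collections import defaultdict
from typing import Callable, Iterable, Iterator, List

# Hypothetical sample rows standing in for the chatbot CSV (not real data).
ROWS = [
    {"id": 1, "sender": "user", "text": "Hi"},
    {"id": 1, "sender": "bot", "text": "Hello!"},
    {"id": 2, "sender": "user", "text": "Bye"},
]


def text_block(id: List[int], sender: List[str], text: List[str]) -> Iterator[tuple]:
    # Current style: each column is a list covering the whole partition.
    lines = [f"{who}: {message}" for message, who in zip(text, sender)]
    # The key is a list of identical values, so only id[0] is meaningful.
    yield id[0], "\n".join(lines)


def run_list_agg(rows: Iterable[dict], key: str, func: Callable) -> Iterator[tuple]:
    """Hypothetical driver: collect every partition into per-column lists."""
    partitions: dict = defaultdict(lambda: defaultdict(list))
    for row in rows:
        part = partitions[row[key]]
        for column, value in row.items():
            part[column].append(value)
    # All partitions are fully materialized before func sees any of them.
    for part in partitions.values():
        yield from func(**part)
```

Calling list(run_list_agg(ROWS, "id", text_block)) returns [(1, "user: Hi\nbot: Hello!"), (2, "user: Bye")]. Inside the call for the first partition, id is [1, 1], and every column sits in memory as a complete list.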

Here is a proposed updated signature:

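from typing import Iterator

from datachain import DataChain

def text_block(id: int, sender: Iterator[str], text: Iterator[str]) -> Iterator[dict]:
    # id is now a single value; sender and text are iterators over the partition.
    columns = zip(text, sender)
    conversation = ""
    for message, speaker in columns:
        conversation = "\n ".join([conversation, f"{speaker}: {message}"])
    yield {"id": id, "conversation": conversation}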

chain = DataChain.from_csv('gs://datachain-demo/chatbot-csv/').agg(text_block, partition_by='id').save()
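One way a runtime could feed the proposed signature without building lists is sketched below in plain Python. This is not how DataChain is implemented: run_iter_agg is a hypothetical driver that groups key-sorted rows with itertools.groupby and splits each group into per-column iterators with itertools.tee. The ROWS sample and the simplified text_block are also illustrative.

```python
from itertools import groupby, tee
from operator import itemgetter
from typing import Callable, Iterable, Iterator, List

# Hypothetical sample rows, already sorted by the partition key.
ROWS = [
    {"id": 1, "sender": "user", "text": "Hi"},
    {"id": 1, "sender": "bot", "text": "Hello!"},
    {"id": 2, "sender": "user", "text": "Bye"},
]


def text_block(id: int, sender: Iterator[str], text: Iterator[str]) -> Iterator[dict]:
    # Proposed style: a scalar key plus lazy iterators over the partition.
    lines = (f"{who}: {message}" for message, who in zip(text, sender))
    yield {"id": id, "conversation": "\n".join(lines)}


def run_iter_agg(
    rows: Iterable[dict], key: str, columns: List[str], func: Callable
) -> Iterator[dict]:
    """Hypothetical driver: stream sorted rows, one iterator per column."""
    for key_value, group in groupby(rows, key=itemgetter(key)):
        # tee() splits the partition's row stream into one copy per column.
        # Because zip() advances the copies in lockstep, tee only has to
        # buffer about one row at a time instead of the whole partition.
        streams = tee(group, len(columns))
        column_iters = {
            name: map(itemgetter(name), stream)
            for name, stream in zip(columns, streams)
        }
        yield from func(**{key: key_value, **column_iters})
```

The driver works just as well when rows is a one-shot generator, so nothing forces a partition into a list. The catch is that groupby needs the rows sorted by the key, which the query itself would have to guarantee (for example with an ORDER BY).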
dmpetrov commented 1 month ago

That's a great idea!

It seems you have also proposed returning a dict and using its keys as output signal names. I recommend creating a separate issue for that: the two ideas are not related to each other, and a dict as an output might be a challenging issue since we already have a built-in dict type.

Without that, the API should look like the one below. @volkfox, please correct me if I'm missing anything.

from typing import Iterator

from datachain import DataChain

def text_block(id: int, sender: Iterator[str], text: Iterator[str]) -> Iterator[tuple[int, str]]:
    columns = zip(text, sender)
    conversation = ""
    for message, speaker in columns:
        conversation = "\n ".join([conversation, f"{speaker}: {message}"])
    yield id, conversation

chain = (
    DataChain.from_csv('gs://datachain-demo/chatbot-csv/')
    .agg(res=text_block, partition_by='id', output={"id": int, "conversation": str})
    .save()
)
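In the tuple-based version, the positions in yield id, conversation line up with the keys of output={"id": int, "conversation": str}. The sketch below assumes values are paired with output names by position (the comment implies this but does not spell it out); to_signals is a hypothetical helper, not a DataChain function, and the type check is an illustrative assumption.

```python
from typing import Any, Dict, Iterator, Tuple

# Output schema copied from the .agg(...) call above.
OUTPUT = {"id": int, "conversation": str}


def text_block(id: int, sender: Iterator[str], text: Iterator[str]) -> Iterator[Tuple[int, str]]:
    lines = (f"{who}: {message}" for message, who in zip(text, sender))
    yield id, "\n".join(lines)


def to_signals(values: Tuple[Any, ...], output: Dict[str, type]) -> Dict[str, Any]:
    """Hypothetical helper: pair a yielded tuple with output names by position."""
    if len(values) != len(output):
        raise ValueError(f"expected {len(output)} values, got {len(values)}")
    record = {}
    for (name, expected), value in zip(output.items(), values):
        # Assumed check: each value must match its declared type.
        if not isinstance(value, expected):
            raise TypeError(
                f"{name}: expected {expected.__name__}, got {type(value).__name__}"
            )
        record[name] = value
    return record
```

A positional contract keeps the schema in one place, but reordering the yield without reordering output would swap values silently whenever the types happen to match. That trade-off is one reason the named-output (dict) idea deserves its own discussion.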