argilla-io / distilabel

Distilabel is a framework for synthetic data and AI feedback for engineers who need fast, reliable and scalable pipelines based on verified research papers.
https://distilabel.argilla.io
Apache License 2.0

[BUG] Input data size != output data size when task batch size < batch size of predecessor #972

Open zye1996 opened 2 months ago

zye1996 commented 2 months ago

Describe the bug The behavior is a bit random. When the text generation input batch size is smaller than the batch size of the previous step and replicas > 1, the final output can be missing some samples. This does not happen every time, but it happens frequently. I suspect it has something to do with batch/multi-processing scheduling.

In the following case, the default LoadDataFromDicts batch size is 50, and the input_batch_size of TextGeneration is set lower than that, in this case 17. The total number of input samples is 60; however, when saving the data to disk, only 52 samples are saved. When the TextGeneration batch size is set greater than 50, all samples are saved successfully.

To Reproduce Code to reproduce

# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from distilabel.llms import AnthropicLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, StepResources
from distilabel.steps.tasks import TextGeneration

resources = StepResources(replicas=8)

with Pipeline(
    name="Knowledge-Graphs",
    description=(
        "Generate knowledge graphs to answer questions, this type of dataset can be used to "
        "steer a model to answer questions with a knowledge graph."
    ),
) as pipeline:
    sample_questions = [
        "Teach me about quantum mechanics",
        "Who is who in The Simpsons family?",
        "Tell me about the evolution of programming languages",
    ] * 20

    load_dataset = LoadDataFromDicts(
        name="load_instructions",
        data=[
            {
                "system_prompt": "You are a knowledge graph expert generator. Help me understand by describing everything as a detailed knowledge graph.",
                "instruction": f"{question}",
            }
            for question in sample_questions
        ],
    )

    text_generation = TextGeneration(
        name="knowledge_graph_generation",
        llm=AnthropicLLM(
            model="claude-3-5-sonnet-20240620",
            generation_kwargs={"max_tokens": 4096, "temperature": 0.5},
        ),
        # Smaller than the LoadDataFromDicts batch size of 50, which triggers the bug.
        input_batch_size=17,
        output_mappings={"model_name": "generation_model"},
        resources=resources,
    )
    load_dataset >> text_generation

if __name__ == "__main__":

    from pathlib import Path

    distiset = pipeline.run(
        parameters={
            text_generation.name: {
                "llm": {"generation_kwargs": {"max_tokens": 2048}}
            }
        },
        use_cache=False,
    )
    distiset.save_to_disk(Path("test_out"),
                          save_card=False,
                          save_pipeline_log=False,
                          save_pipeline_config=False)

Expected behaviour All 60 input samples are present in the dataset saved to disk; the output dataset size matches the input dataset size.

Screenshots Screenshot 2024-09-11 at 9 14 03 PM

zye1996 commented 2 months ago

Looks like some batches are processed twice; this seems more like a multi-processing issue.

Screenshot 2024-09-11 at 9 45 38 PM

gabrielmbmb commented 2 months ago

Thanks for reporting @zye1996. I'll take a look.

zye1996 commented 2 months ago

@gabrielmbmb should this line return False? Otherwise, if the last batch arrives earlier than the previous batches, data is forced to be sent to the next step, and some samples can go missing because they are never grouped into another batch. Let me know if a PR is needed.

https://github.com/argilla-io/distilabel/blob/e67864e1b65a2633205bb3966cde41e16373862d/src/distilabel/pipeline/batch_manager.py#L514
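To illustrate the suspicion, here is a simplified, hypothetical sketch of the flushing logic; the names and structure are invented for illustration and do not match the actual batch_manager.py code:

# Hypothetical, simplified model of the suspected behaviour; not distilabel code.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class BufferedStepInput:
    input_batch_size: int
    rows: List[Dict] = field(default_factory=list)
    last_batch_received: bool = False  # set once the predecessor's final batch arrives

    def add_batch(self, batch_rows: List[Dict], last_batch: bool) -> None:
        self.rows.extend(batch_rows)
        self.last_batch_received = self.last_batch_received or last_batch

    def can_create_batch(self) -> bool:
        if len(self.rows) >= self.input_batch_size:
            return True
        # Suspected problem: with replicas > 1 the "last" batch can arrive before
        # earlier batches, so returning True here flushes a partial batch; rows that
        # arrive afterwards may never be grouped into another batch and are dropped
        # from the final dataset. Returning False would wait for a full batch instead.
        return self.last_batch_received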

thesven commented 1 month ago

I've also started noticing this on a pipeline I've created. Using an input_batch_size of one on some text generation tasks led to the final dataset containing only one row for each processed batch of the previous output, which had been created using a step mixin and could not have an enforced batch size. @gabrielmbmb I have some code that exhibits the issue that I can share as well.