argilla-io / distilabel

Distilabel is a framework for synthetic data and AI feedback for engineers who need fast, reliable and scalable pipelines based on verified research papers.
https://distilabel.argilla.io
Apache License 2.0

[IMPLEMENTATION] magpie #740

Closed gabrielmbmb closed 2 months ago

fpreiss commented 3 months ago

I have tried to implement the prompting strategy of the Magpie paper using distilabel's Ollama integration and noticed that the current implementation does not allow me to override the chat template. I believe the /api/generate endpoint would need to be wrapped instead of the /api/chat endpoint. I had some success with the following:

from typing import Any, Literal

from distilabel.llms import OllamaLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import StandardInput
from ollama._types import Options

# Assumed Ollama model tag; adjust to whatever tag is pulled locally.
LLAMA3_8B = "llama3:8b"

TEMPLATE_OVERRIDES: dict[str, str] = {
    # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/#special-tokens-used-with-meta-llama-3
    LLAMA3_8B: "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
}

class OllamaMagpieLLM(OllamaLLM):
    """Magpie compatibility layer for Ollama."""

    async def agenerate(
        self,
        input: StandardInput,
        format: Literal["", "json"] = "",
        # TODO: include relevant options from `Options` in `agenerate` method.
        options: Options | None = None,
        keep_alive: bool | None = None,
    ) -> GenerateOutput:
        """Override of the `OllamaLLM.agenerate` method make Ollama fill the user message.

        The original implementation uses Ollama's chat endpoint instead of the generate endpoint.
        This simplifies implementing multi-turn conversations, but we can't manipulate the prompt template.
        """
        try:
            # Single-turn only for now: use the first message's content as the raw prompt.
            prompt = input[0]["content"]  # needs some work for multi-turn support
            completion: dict[str, Any] = await self._aclient.generate(
                prompt=prompt,
                model=self.model,
                template=TEMPLATE_OVERRIDES[self.model],
                stream=False,
                format=format,
                options=options,
                keep_alive=keep_alive,
            )
            return [completion["response"]]
        except Exception as e:
            self._logger.warning(
                f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )

Note that, as of writing this, the prompt in the generate call has to be a non-empty string in order to generate the user instructions as outlined in the paper. This seems to be an issue on Ollama's / llama.cpp's side.
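For reference, here is a minimal usage sketch of the class above; the model tag, the single-whitespace user turn, and the use of `asyncio.run` are my own assumptions rather than anything prescribed by distilabel:

import asyncio

# Hypothetical usage: a lone-whitespace user turn satisfies the non-empty prompt
# requirement mentioned above, while the template override leaves the user header
# open so the model writes the instruction itself, as in the Magpie paper.
llm = OllamaMagpieLLM(model=LLAMA3_8B)
llm.load()

generation = asyncio.run(llm.agenerate(input=[{"role": "user", "content": " "}]))
print(generation)  # e.g. ["<model-written user instruction>"]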

gabrielmbmb commented 2 months ago

Hi @fpreiss, for now we have implemented Magpie for TransformersLLM, InferenceEndpointsLLM, and vLLM. We will work on adding compatibility for the rest of the LLMs in the next release.
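For context, the released implementation is driven through the `Magpie` / `MagpieGenerator` tasks; the snippet below is only a rough sketch based on the documentation, and parameter names such as `magpie_pre_query_template`, `n_turns` and `num_rows` should be checked against the installed version:

from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import MagpieGenerator

# Illustrative values: model/tokenizer ids and generation kwargs are placeholders.
generator = MagpieGenerator(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        magpie_pre_query_template="llama3",
        generation_kwargs={"temperature": 1.0, "max_new_tokens": 256},
    ),
    n_turns=1,
    num_rows=10,
)

generator.load()
result = next(generator.process())  # first batch of generated instruction/response pairs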