langchain-ai / langchain-extract

🦜⛏️ Did you say you like data?
https://extract.langchain.com/

combine aws sagemaker endpoint with extraction function? #119

Open MaxS3552284 opened 6 months ago

MaxS3552284 commented 6 months ago

Hello, I tried to combine the langchain extraction functions (from https://python.langchain.com/docs/use_cases/extraction/quickstart) with a mistral-7b-instruct endpoint running on AWS SageMaker.

I replaced llm = ChatMistralAI(model="mistral-large-latest", temperature=0), as described in the quickstart, with:

import json
from typing import Dict, List

from langchain_community.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain_core.callbacks import StreamingStdOutCallbackHandler


def format_instructions(instructions: List[Dict[str, str]]) -> str:
    # from sagemaker notebook
    """Format instructions where conversation roles must alternate user/assistant/user/assistant/..."""
    prompt: List[str] = []
    for user, answer in zip(instructions[::2], instructions[1::2]):
        prompt.extend(["<s>", "[INST] ", user["content"].strip(), " [/INST] ",
                       answer["content"].strip(), "</s>"])
    prompt.extend(["<s>", "[INST] ", instructions[-1]["content"].strip(), " [/INST] "])
    return "".join(prompt)  # join list into a single prompt string


class ContentHandler_mistral7B(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        input_user = [{"role": "user", "content": prompt}]
        prompt_formated = format_instructions(input_user)
        payload = {
            "inputs": prompt_formated,
            "parameters": {"max_new_tokens": 256, "do_sample": True, "temperature": 0.1},
        }
        return json.dumps(payload, ensure_ascii=False).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


# handler and helper must be defined before constructing the endpoint
llm = SagemakerEndpoint(
    endpoint_name=MISTRAL_ENDPOINT,
    region_name=AWS_REGION,
    content_handler=ContentHandler_mistral7B(),
    callbacks=[StreamingStdOutCallbackHandler()],
    endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
)
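As a sanity check, format_instructions can be exercised on its own (note its annotated return type should be str, not List[str], since it joins the list before returning):

```python
from typing import Dict, List


def format_instructions(instructions: List[Dict[str, str]]) -> str:
    """Join alternating user/assistant turns into a Mistral [INST] prompt string."""
    prompt: List[str] = []
    for user, answer in zip(instructions[::2], instructions[1::2]):
        prompt.extend(["<s>", "[INST] ", user["content"].strip(), " [/INST] ",
                       answer["content"].strip(), "</s>"])
    prompt.extend(["<s>", "[INST] ", instructions[-1]["content"].strip(), " [/INST] "])
    return "".join(prompt)


# A single user turn: no completed pairs, only the trailing [INST] block.
print(format_instructions([{"role": "user", "content": "Hi"}]))
# → <s>[INST] Hi [/INST]

# One completed user/assistant pair followed by a new user turn.
print(format_instructions([
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
    {"role": "user", "content": "Go"},
]))
# → <s>[INST] Hi [/INST] Hello</s><s>[INST] Go [/INST]
```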

which resulted in the following when executing runnable = prompt | llm.with_structured_output(schema=Person):

LangChainBetaWarning: The function `with_structured_output` is in beta. It is actively being worked on, so the API may change.
  warn_beta(
NotImplementedError

I suspect that I probably messed up the input text transformation within the content handler, but I can't really figure out how.

Can you help me?
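For anyone hitting the same NotImplementedError: with_structured_output needs per-model support (typically tool/function calling on a chat model), and the plain string-in/string-out SagemakerEndpoint LLM class does not implement it, so the base-class NotImplementedError is raised regardless of the content handler. One fallback is to ask the endpoint for JSON explicitly and parse the reply yourself. Below is a minimal sketch of that idea; the prompt wording, the fake_llm stand-in, the flat name/height_in_meters schema, and the extract_structured helper are all illustrative assumptions, not the library's API:

```python
import json
import re
from typing import Callable, Dict


def extract_structured(llm_call: Callable[[str], str], text: str) -> Dict:
    """Ask the model for a JSON object and parse the first one out of its reply.

    `llm_call` is any prompt -> completion function (e.g. wrapping the
    SageMaker endpoint's invoke); the schema is a hypothetical Person example.
    """
    prompt = (
        "Extract the person's name and height_in_meters from the passage. "
        'Reply with a single JSON object with keys "name" and '
        '"height_in_meters" and nothing else.\n\nPassage: ' + text
    )
    reply = llm_call(prompt)
    match = re.search(r"\{.*\}", reply, re.DOTALL)  # tolerate chatter around the JSON
    if match is None:
        raise ValueError(f"no JSON object in model reply: {reply!r}")
    return json.loads(match.group(0))


# Stand-in for a real endpoint call, just to show the shape of the contract.
def fake_llm(prompt: str) -> str:
    return 'Sure! {"name": "Anna", "height_in_meters": 1.8}'


print(extract_structured(fake_llm, "Anna is 1.8 meters tall."))
# → {'name': 'Anna', 'height_in_meters': 1.8}
```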