LoRAX + Outlines: Better JSON Extraction with Structured Generation and LoRA #709

Open irthomasthomas opened 8 months ago

irthomasthomas commented 8 months ago


DESCRIPTION:
LoRAX + Outlines: Better JSON Extraction with Structured Generation and LoRA
March 3, 2024 · 6 Min Read

Jeffrey Tang
Travis Addair

LoRAX is an open source inference server for large language models that supports serving thousands of fine-tuned adapters on top of a shared base model running on a single GPU. The v0.8 release of LoRAX introduces native support for the popular Outlines library, enabling LoRAX to generate output that consistently follows a given schema.

Large language models have proven to be quite good at generating human-readable text, but sometimes you need the output of your LLM to be consumed not just by humans, but by other automated systems. JSON is a popular format for structuring data that can be further constrained by a specific schema. This raises the question: how can we constrain our LLM to only output JSON that adheres to a chosen schema?

In this blog, we’ll introduce two popular methods for extracting and generating JSON using LLMs: structured generation and fine-tuning. We’ll show how you can specify an Outlines-compatible schema in your request to generate output that adheres to a specific format. We’ll demonstrate that while the schema can enforce the right format, it can’t guarantee that the right content is used to populate the properties of the JSON object.

By fine-tuning the LLM with just a few lines of code, we’ll create a LoRA adapter for extracting the desired JSON properties from our input. We’ll see how fine-tuning gets the right content, but can’t guarantee the correct format for the output. Finally, we’ll combine the two approaches at once with LoRAX to get the best of both worlds: output that follows our desired schema and contains all the correct content for the JSON properties.

What is JSON mode, and how do I use it?
JSON mode in LoRAX is a form of structured generation, also sometimes called constrained decoding. It’s a way of forcing an LLM to generate text that conforms to a JSON schema. To use JSON mode, first start LoRAX:

model=mistralai/Mistral-7B-v0.1
volume=$PWD/data  # share a volume with the container as a weight cache

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/predibase/lorax:latest --model-id $model

Then, define your desired schema (this example uses a Pydantic model, but a manually constructed schema works as well):

from pydantic import BaseModel, constr

class Person(BaseModel):
    name: constr(max_length=10)
    age: int

schema = Person.schema()
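
For reference, schema here is just a plain JSON Schema dictionary. A quick way to inspect it (the exact keys vary by Pydantic version; on Pydantic v2 you would call Person.model_json_schema() instead):

import json

print(json.dumps(schema, indent=2))
# Roughly:
# {
#   "title": "Person",
#   "type": "object",
#   "properties": {
#     "name": {"title": "Name", "maxLength": 10, "type": "string"},
#     "age": {"title": "Age", "type": "integer"}
#   },
#   "required": ["name", "age"]
# }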

Finally, pass the schema to LoRAX as part of the new response_format parameter when prompting:

from lorax import Client

client = Client(\"http://127.0.0.1:8080\")

client.generate(
    \"Create a person description for me\", 
    response_format={\"type\": \"json_object\", \"schema\": schema}
)
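
Since the response is just a string of JSON text, you can validate it back into the Pydantic model on the client side. A minimal sketch, assuming the lorax Python client returns an object whose generated text is available as resp.generated_text:

resp = client.generate(
    "Create a person description for me",
    response_format={"type": "json_object", "schema": schema}
)

# Parse the raw JSON string back into a typed Person object.
person = Person.parse_raw(resp.generated_text)  # Pydantic v2: Person.model_validate_json(...)
print(person.name, person.age)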

Case Study: Structured Generation vs. Fine-tuning
Let’s take a closer look at how structured generation works in practice and how you can use it in conjunction with fine-tuning to get better, more reliable LLM outputs while avoiding common pitfalls.

Suppose we want to use Mistral-7B to perform a Named Entity Recognition (NER) task. For that, we’ll use a version of the CoNLL++ dataset that we’ve modified for LLM use. This dataset contains rows like:

Input: EU rejects German call to boycott British lamb .
Output: {"person": [], "organization": ["EU"], "location": [], "miscellaneous": ["German", "British"]}

Input: Peter Blackburn
Output: {"person": ["Peter Blackburn"], "organization": [], "location": [], "miscellaneous": []}

The output that we want the LLM to generate is a JSON object categorizing entities in the input as a person, organization, location, or miscellaneous.

First, let’s see how well the base Mistral-7B model performs.

prompt_template = \"\"\"
Your task is a Named Entity Recognition (NER) task. Predict the category of
each entity, then place the entity into the list associated with the 
category in an output JSON payload. Below is an example:

Input: EU rejects German call to boycott British lamb . Output: {{\"person\":
[], \"organization\": [\"EU\"], \"location\": [], \"miscellaneous\": [\"German\",
\"British\"]}}

Now, complete the task.

Input: {input} Output:\"\"\"

# Base Mistral-7B
client.generate(
    prompt_template.format(input="Only France and Britain backed Fischler 's proposal ."),
    max_new_tokens=128,
)

The base model tries to generate something reasonable, but quickly goes off the rails:

{\"person\":
[], \"organization\": [\"France\", \"Britain\"], \"location\": [], \"miscellaneous\":
[\"Fischler\"]}

Adding Outlines for Structured Generation
Now let’s add JSON schema enforcement to the mix. First, we’ll define a suitable schema:

from typing import List

class Output(BaseModel):
    person: List[str]
    organization: List[str]
    location: List[str]
    miscellaneous: List[str]

schema = Output.schema()

Now we can easily apply this schema to our test prompt:

# Base Mistral-7B + JSON schema
client.generate(
    prompt_template.format(input="Only France and Britain backed Fischler 's proposal ."),
    response_format={
        "type": "json_object",
        "schema": schema,
    },
    max_new_tokens=128,
)

Which yields:

{
    "person": [],
    "organization": ["France", "Britain"],
    "location": [],
    "miscellaneous": ["Fischler"]
}

Not bad! We get a JSON object that follows our schema. Unfortunately, the entities aren’t classified accurately. Fischler should be a person, and France and Britain should be locations. Applying a schema has improved the model’s ability to produce well-structured outputs, but not its ability to produce quality outputs within that structure.

Fine-tuning a LoRA Adapter
Our goal is to steer the LLM to be better tailored to the task of JSON named-entity extraction. The most popular technique for such task-specialization is fine-tuning. Using a platform like Predibase, we can fine-tune a small task-specific LoRA adapter with just a few lines of code:

from predibase import Predibase

pb = Predibase()

# Choose the base model and dataset to fine-tune on
mistral = pb.models.from_hf("mistralai/Mistral-7B-v0.1")
conllpp = pb.datasets.from_file("/mydata/conllpp_train.csv")

# Train with our chosen prompt template from above
my_adapter = pb.adapters.create(
    {
        "base_model": mistral,
        "dataset": conllpp,
        "prompt": prompt_template,
        "target": "output"
    }
)

Because fine-tuning a model can take some time, we’ve gone ahead and run this step and uploaded the resulting artifacts to HuggingFace here: predibase/conllpp.

If you’ve used an LLM server in the past, you might expect we’d need to redeploy our server with the newly fine-tuned adapter in order to evaluate it. But thanks to LoRAX’s dynamic adapter loading, we can try it out immediately. Returning to the LoRAX client from before, prompting with this new adapter is as easy as providing the HuggingFace model name as the adapter_id:

# Mistral-7B + Adapter
client.generate(
    prompt_template.format(input="Only France and Britain backed Fischler 's proposal ."),
    adapter_id="predibase/conllpp",
    max_new_tokens=128,
)

Yielding:

{
    "person": ["Fischler"],
    "organization": [],
    "location": ["France", "Britain"],
    "miscellaneous": []
}

Combining LoRA + Structured Generation
For our previous example input, the fine-tuned LoRA adapter generated valid JSON and tagged the entity types correctly, but will it always get this right?

Let’s try a different example:

# Mistral-7B + Adapter
client.generate(
    prompt_template.format(input="I 've also got a contract to play for New South Wales in the Super 12 next year . "),
    adapter_id="predibase/conllpp",
    max_new_tokens=128,
)

Which yields:

{
    "person": [],
    "organization": ["New South Wales"],
    "location": [],
    "miscellaneous": [],
    "Super 12"]
}

In this case, the model tagged most of the entities correctly, but the payload is not valid JSON because of the trailing extra "Super 12". The fine-tuned adapter extracted the right entities, but didn’t get the JSON format correct.

Now what if we combine our fine-tuned adapter with our response schema above?

client.generate(
    prompt_template.format(input="I 've also got a contract to play for New South Wales in the Super 12 next year . "),
    adapter_id="predibase/conllpp",
    response_format={
        "type": "json_object",
        "schema": schema,
    },
    max_new_tokens=128,
)

Great! Now the model is once again producing valid JSON with tagged entities:

{
    "person": [],
    "organization": ["New South Wales"],
    "location": [],
    "miscellaneous": ["Super 12"]
}

Benchmarks
So far, it looks like both JSON schema enforcement and model fine-tuning are useful techniques to help address our NER use case. But we’ve only looked at one example input. To better quantify the performance of these approaches, we’ll benchmark them against a held-out test set of ~3.5k examples.

The metric we’re interested in is multiset Jaccard similarity, so a score closer to 1.0 is better and a score closer to 0.0 is worse. We’ll also collect some statistics on the generated outputs: how many were parseable as JSON, how many had formatting errors, and how many fully met our desired structure.
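
For reference, here is a minimal sketch of a multiset Jaccard similarity over (category, entity) pairs; this is our own illustration of the metric, not the exact benchmark script:

from collections import Counter

def multiset_jaccard(pred: dict, gold: dict) -> float:
    # Flatten each JSON payload into a multiset of (category, entity) pairs.
    pred_counts = Counter((cat, ent) for cat, ents in pred.items() for ent in ents)
    gold_counts = Counter((cat, ent) for cat, ents in gold.items() for ent in ents)
    intersection = sum((pred_counts & gold_counts).values())
    union = sum((pred_counts | gold_counts).values())
    return intersection / union if union else 1.0

# The schema-only output vs. the ground truth for the Fischler sentence:
pred = {"person": [], "organization": ["France", "Britain"], "location": [], "miscellaneous": ["Fischler"]}
gold = {"person": ["Fischler"], "organization": [], "location": ["France", "Britain"], "miscellaneous": []}
print(multiset_jaccard(pred, gold))  # 0.0 -- right entity strings, wrong categories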

First let’s measure the percentage of responses from the model that conform to our desired JSON schema. Unsurprisingly, when Outlines is used nearly every response follows the schema, demonstrating that the structured generation process is working as intended. Interestingly, the fine-tuned model on its own did decently well at following the schema, even though there is no enforcement mechanism to guarantee it does so other than what it learned during fine-tuning. Nevertheless, it’s clear from these results that structured generation is a more reliable mechanism to enforce the schema than fine-tuning.

[Chart: percentage of responses in the correct JSON schema format]

Now let’s measure how closely the content of the properties in the JSON response returned by the model match the ground truth. This time, even though not as many payloads come back in the correct format, we see that fine-tuning does significantly better than structured generation alone at populating the correct content in the response. But most importantly: it is the combination of fine-tuning and structured generation together that achieve the best overall performance, both in number of correctly formatted payloads and in overall similarity.

[Chart: multiset Jaccard similarity]

Looking at these results, we see a clear trajectory emerge. The base model, unsurprisingly, has the worst performance score and is unable to generate valid JSON in most cases.

Applying JSON schema constraints to the model results in far more well-formed payloads, but a relatively low similarity score of 0.50 confirms our suspicions that the model isn’t doing well on the NER task despite the formatting improvements.

Next, we see that fine-tuning significantly improves on both the base model and the schema-only approach. A high similarity score indicates good performance on the NER task, but a notable percentage of malformed outputs is still present.

Finally, by combining a fine-tuned adapter and schema enforcement, we get the best performance yet. 99.9% of generated outputs were well-formed, and the overall similarity score easily surpasses the performance of the adapter alone.

Avoiding Pitfalls
Before wrapping up this case study, let’s cover how to avoid common pitfalls when using structured generation.

First, it’s important to ensure that your max token limit is high enough to accommodate the JSON object you want to generate.

Structured generation works by altering the probability distribution for candidate tokens so that only tokens which would produce valid JSON can be chosen. However, it does not “plan ahead” based on the token limit. If this limit is too low, you’ll end up with half an object!
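
One cheap guard (a sketch, assuming the generated text is available as resp.generated_text on the client response) is to check whether the output still parses and retry with a larger limit if it doesn’t:

import json

resp = client.generate(
    prompt_template.format(input="Peter Blackburn"),
    response_format={"type": "json_object", "schema": schema},
    max_new_tokens=16,  # deliberately too small for this payload
)

try:
    json.loads(resp.generated_text)
except json.JSONDecodeError:
    print("Output was cut off by max_new_tokens; raise the limit and retry.")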

Second, structured generation only prevents invalid tokens from being generated. It won’t force your model to select the right tokens.
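
A toy illustration of the idea (not the actual Outlines implementation): the schema’s state machine masks out forbidden candidate tokens before sampling, but the model’s own probabilities still decide among the remaining valid ones.

import math

# Toy next-token distribution: scores for four candidate tokens.
logits = {"tok_a": 2.0, "tok_b": 1.5, "tok_c": 0.5, "tok_d": -1.0}
# Tokens the schema's state machine allows at this position.
allowed = {"tok_b", "tok_c"}

# Constrained decoding: drop forbidden tokens, renormalize the rest.
masked = {tok: score for tok, score in logits.items() if tok in allowed}
z = sum(math.exp(s) for s in masked.values())
probs = {tok: math.exp(s) / z for tok, s in masked.items()}
print(probs)  # only schema-valid tokens keep probability; the model still chooses among them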

As we saw in the base model + schema example above, schema enforcement gave us a valid JSON object, but didn’t classify the entities correctly.

Finally, if your schema and model clash, performance can actually get worse.

As a concrete example, consider the adapter we used in our example above. It was fine-tuned to strongly prefer generating a JSON output with fields in a specific order: “person”, “organization”, “location”, “miscellaneous”.

As it turns out, when Outlines internally converts a JSON schema into a regular expression, and then into a state machine, it also imposes an ordering constraint on object fields.

While putting together this case study, we initially failed to take this into account. Our original schema draft simply listed the desired fields in alphabetical order: “location”, “miscellaneous”, “organization”, “person”.

Because of this clash, the schema kept forcing the model to select tokens with relatively low probability when generating field names, sending the model down the wrong path for future forward passes. Not only did this degrade the quality of entity extraction, it also sometimes led to a loop where the model repeatedly generated whitespace tokens - hundreds of them - until eventually hitting the token limit.

This performance hit was borne out by benchmarks:

finetuned_raw_score              0.709870 # Mistral-7B + adapter
finetuned_constrained_old_score  0.649936 # Mistral-7B + adapter + bad schema
finetuned_constrained_score      0.804054 # Mistral-7B + adapter + good schema
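
In code, the fix amounted to declaring the schema’s fields in the order the adapter was fine-tuned to emit. A sketch of the two variants (the class names here are just illustrative):

# Field order that clashes with the adapter (alphabetical):
class OutputAlphabetical(BaseModel):
    location: List[str]
    miscellaneous: List[str]
    organization: List[str]
    person: List[str]

# Field order that matches the fine-tuning data, and therefore the token
# sequence the adapter already prefers to generate:
class OutputMatching(BaseModel):
    person: List[str]
    organization: List[str]
    location: List[str]
    miscellaneous: List[str]

schema = OutputMatching.schema()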

In other words, JSON mode is a useful tool for keeping your model on track and guarding against undesirable output formats, but it’s important to understand how it interacts with your fine-tuned model’s behavior and ensure they work together, not against each other.

What’s next? Fine-tune and serve your own LLMs.
In this case study, we showed you how to generate accurate and structurally correct JSON by fine-tuning Mistral on Predibase and serving with LoRAX and the Outlines library. Here are some additional references to help you get started fine-tuning and serving your own LLMs:

irthomasthomas commented 8 months ago

Related content

638 - Similarity score: 0.88

645 - Similarity score: 0.88

494 - Similarity score: 0.87

505 - Similarity score: 0.86

660 - Similarity score: 0.86

515 - Similarity score: 0.86