dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

Bad performance in generating string properties given a JSON schema #994

Closed: liqul closed this issue 3 months ago

liqul commented 3 months ago

Describe the issue as clearly as possible:

I'm not sure if I'm missing something. Basically, I want to extract information from a provided paragraph based on a JSON schema. When the schema contains string-typed properties, the output values are often wrong, e.g. ", " or ": ". I looked into the implementation in json_schema.py and can see that the default regex for strings is defined as

STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\\)'
STRING = f'"{STRING_INNER}*"'

If I change the definition to something simpler like r'[\w ]', the output seems to get better, but I haven't tested this comprehensively. Have you tested this scenario before, and do you know what might be causing this issue?
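A quick way to see what each pattern permits is to test sample values with Python's `re` module. This is only an approximation, since outlines compiles these patterns into its own finite-state machine rather than calling `re`, and the helper name `is_json_string` is illustrative. The check suggests that the original pattern accepts ", " as readily as a real airport name, so the constraint alone cannot rule it out, while `[\w ]` rejects ", " but also rejects legitimate values such as an IANA timezone containing "/".

```python
import re

# STRING_INNER as quoted in the issue (outlines 0.0.45), with the closing quote restored.
OLD_STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\\)'
# The simplification tried in the issue: word characters and spaces only.
SIMPLE_STRING_INNER = r'[\w ]'


def is_json_string(value: str, inner: str) -> bool:
    """Return True if the quoted value fully matches '"' + inner* + '"'."""
    return re.fullmatch(f'"{inner}*"', f'"{value}"') is not None


samples = ["Los Angeles International Airport", ", ", "America/Los_Angeles"]
for s in samples:
    print(f"{s!r:40} old={is_json_string(s, OLD_STRING_INNER)} "
          f"simple={is_json_string(s, SIMPLE_STRING_INNER)}")
```

In other words, the regex only constrains the form of a string, not whether its content is sensible; which valid string gets picked is still up to the model.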

Steps/code to reproduce the bug:

import timeout_decorator  # third-party: pip install timeout-decorator
from typing import Any


def get_model_outlines(model_id: str):
    print(f"Loading model {model_id}")
    from outlines import models
    import torch
    model = models.transformers(
        model_id,
        model_kwargs={
            "torch_dtype": torch.bfloat16,
            "device_map": "auto"
        }
    )
    print(f"Model {model_id} loaded")
    return model

@timeout_decorator.timeout(120, timeout_exception=TimeoutError)
def run_outlines(
    prompt: str,
    schema: dict,
    model: Any,
    max_new_tokens: int = 512,
    top_p: float = 1.0,
    temperature: float = 0.0001
):
    from outlines import generate, samplers
    import json

    parsed_schema = json.dumps(schema, indent=2)

    sampler = samplers.multinomial(1, top_p=top_p, temperature=temperature)
    generator = generate.json(model, parsed_schema, sampler=sampler, whitespace_pattern="")
    response = generator(prompt, max_tokens=max_new_tokens)
    if isinstance(response, list):
        response = response[0]
    return json.dumps(response, indent=2)

loaded = get_model_outlines("meta-llama/Meta-Llama-3-8B-Instruct")
# loaded = get_model_outlines("mistralai/Mistral-7B-Instruct-v0.2")

prompt = """The description of Los Angeles International Airport is: Los Angeles International Airport (LAX) is one of the busiest airports in the world, located in Los Angeles, California, USA. The IATA code for the airport is LAX, and the ICAO code is KLAX. It is situated in the Pacific Time Zone (America/Los_Angeles). The airport's coordinates are approximately 33.9416° N latitude and 118.4085° W longitude. LAX serves as a major gateway for international and domestic flights, connecting millions of passengers to various destinations around the globe.

Please provide information about Los Angeles International Airport in the following format: {'$schema': 'http://json-schema.org/draft-07/schema#', 'type': 'object', 'title': 'Airport', 'required': ['name', 'IATA', 'ICAO', 'location', 'timezone'], 'properties': {'name': {'type': 'string', 'description': 'The name of the airport.'}, 'IATA': {'type': 'string', 'pattern': '^[A-Z]{3}$', 'description': 'The IATA code of the airport, a 3-letter code.'}, 'ICAO': {'type': 'string', 'pattern': '^[A-Z]{4}$', 'description': 'The ICAO code of the airport, a 4-letter code.'}, 'location': {'type': 'object', 'required': ['city', 'country', 'coordinates'], 'properties': {'city': {'type': 'string', 'description': 'The city where the airport is located.'}, 'country': {'type': 'string', 'description': 'The country where the airport is located.'}, 'coordinates': {'type': 'object', 'required': ['latitude', 'longitude'], 'properties': {'latitude': {'type': 'number', 'description': "The latitude of the airport's location."}, 'longitude': {'type': 'number', 'description': "The longitude of the airport's location."}}}}}, 'timezone': {'type': 'string', 'description': "The timezone of the airport, in IANA timezone format (e.g., 'America/New_York')."}}}"""

schema = {'$schema': 'http://json-schema.org/draft-07/schema#', 'type': 'object', 'title': 'Airport', 'required': ['name', 'IATA', 'ICAO', 'location', 'timezone'], 'properties': {'name': {'type': 'string', 'description': 'The name of the airport.'}, 'IATA': {'type': 'string', 'pattern': '^[A-Z]{3}$', 'description': 'The IATA code of the airport, a 3-letter code.'}, 'ICAO': {'type': 'string', 'pattern': '^[A-Z]{4}$', 'description': 'The ICAO code of the airport, a 4-letter code.'}, 'location': {'type': 'object', 'required': ['city', 'country', 'coordinates'], 'properties': {'city': {'type': 'string', 'description': 'The city where the airport is located.'}, 'country': {'type': 'string', 'description': 'The country where the airport is located.'}, 'coordinates': {'type': 'object', 'required': ['latitude', 'longitude'], 'properties': {'latitude': {'type': 'number', 'description': "The latitude of the airport's location."}, 'longitude': {'type': 'number', 'description': "The longitude of the airport's location."}}}}}, 'timezone': {'type': 'string', 'description': "The timezone of the airport, in IANA timezone format (e.g., 'America/New_York')."}}}

print(run_outlines(prompt, schema, loaded))

Actual output:

{
  "name": ", ",
  "IATA": "LAX",
  "ICAO": "KLAX",
  "location": {
    "city": "Los Angeles",
    "country": "USA",
    "coordinates": {
      "latitude": 33.9416,
      "longitude": -118.4085
    }
  },
  "timezone": "America/Los_Angeles"
}

Expected result:

{
  "name": "Los Angeles International Airport",
  "IATA": "LAX",
  "ICAO": "KLAX",
  "location": {
    "city": "Los Angeles",
    "country": "USA",
    "coordinates": {
      "latitude": 33.9416,
      "longitude": -118.4085
    }
  },
  "timezone": "America/Los_Angeles"
}

Outlines/Python version information:

Outlines version: 0.0.45

Context for the issue:

The example above is not the worst case. Sometimes many string-typed properties come out as ", ", as shown below:

{
  "name": ", ",
  "IATA": "HND",
  "ICAO": "RJTT",
  "location": {
    "city": ",",
    "country": "Japan",
    "coordinates": {
      "latitude": 35.5494,
      "longitude": 139.7798
    }
  },
  "timezone": "Asia/Tokyo"
}

{
  "name": ", ",
  "address": {
    "street": ", ",
    "city": ", ",
    "state": ", ",
    "postalCode": ", "
  },
  "type": "Teaching",
  "numberOfBeds": 1265
}

lapp0 commented 3 months ago

The pattern in main changed yesterday:

# allow `\"`, `\\`, or any character which isn't a control sequence
STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
STRING = f'"{STRING_INNER}*"'

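For the inner part of a string to be valid JSON, we need to ensure each character is either an escaped quote or backslash (`\"` or `\\`), or any character that isn't `"`, `\`, or a control character.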
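The two cases can be checked directly with Python's `re`, again only as an approximation of how outlines compiles the pattern. The sketch below (helper name `matches` is illustrative) compares the updated pattern from `main` with the older pattern quoted at the top of the issue; the older one accepts an escaped backslash but not an escaped quote.

```python
import re

# Updated pattern from main, as quoted above.
NEW_STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
# Older pattern from the issue report (closing quote restored).
OLD_STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\\)'


def matches(inner: str, json_text: str) -> bool:
    """True if json_text (including its surrounding quotes) matches '"' inner* '"'."""
    return re.fullmatch(f'"{inner}*"', json_text) is not None


cases = {
    r'"plain text"': True,     # ordinary characters
    r'"say \"hi\""': True,     # escaped quote
    r'"C:\\temp"': True,       # escaped backslash
    '"bare " quote"': False,   # unescaped quote inside the string
    '"tab\there"': False,      # raw control character (U+0009)
}
for text, expected in cases.items():
    print(f"{text!r:20} new={matches(NEW_STRING_INNER, text)} "
          f"old={matches(OLD_STRING_INNER, text)} expected={expected}")
```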

It would be great if we could simplify it, but [\w ] doesn't work, e.g.

>>> print(re.match(r'[\w ]', "!"))
None

You can test your new STRING_INNER here https://github.com/outlines-dev/outlines/blob/main/tests/fsm/test_json_schema.py#L117-L132

liqul commented 3 months ago

Thanks for the quick response. I know my simplification has many limitations. But the issue I observed is that the accuracy of string properties in the generated JSON objects is consistently low compared to other types, even when the value is common knowledge. I was not suggesting changing the definition.

lapp0 commented 3 months ago

I've run into a similar issue.

I think there are two solutions here:

Let me know if this makes sense to you or if you have any other ideas

liqul commented 3 months ago

I'm comparing different libraries for constrained generation, using the same model and the same prompt+schema with each approach. That's how I found that outlines performs relatively poorly on string types compared to other libraries, and I'm confident this is not solely a problem with the model.

The interesting part is that different libraries, although based on a similar underlying logits-processing technique, implement the regex for the string type differently. That's why I believe a different default regex could improve performance.
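One way to make "string properties are less accurate than other types" measurable across libraries is to score each leaf field of the generated object against a reference answer and group the scores by type. The helpers below (`leaves`, `accuracy_by_type`) are a hypothetical sketch using exact matching; the data are the expected and incorrect LAX outputs from this issue.

```python
from collections import defaultdict
from typing import Any, Dict, Iterator, Tuple


def leaves(obj: Any, prefix: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (dotted_path, value) for every non-dict value in a nested dict."""
    if isinstance(obj, dict):
        for key, value in obj.items():
            yield from leaves(value, f"{prefix}.{key}" if prefix else key)
    else:
        yield prefix, obj


def accuracy_by_type(reference: Dict[str, Any], generated: Dict[str, Any]) -> Dict[str, float]:
    """Exact-match accuracy of generated leaves, grouped by the reference value's Python type."""
    flat_generated = dict(leaves(generated))
    hits: Dict[str, int] = defaultdict(int)
    totals: Dict[str, int] = defaultdict(int)
    for path, expected in leaves(reference):
        kind = type(expected).__name__  # "str", "float", "int", ...
        totals[kind] += 1
        hits[kind] += int(flat_generated.get(path) == expected)
    return {kind: hits[kind] / totals[kind] for kind in totals}


# The expected LAX object and the incorrect output reported in this issue.
reference = {
    "name": "Los Angeles International Airport",
    "IATA": "LAX",
    "ICAO": "KLAX",
    "location": {
        "city": "Los Angeles",
        "country": "USA",
        "coordinates": {"latitude": 33.9416, "longitude": -118.4085},
    },
    "timezone": "America/Los_Angeles",
}
generated = {**reference, "name": ", "}

print(accuracy_by_type(reference, generated))  # str: 5 of 6 correct, float: 2 of 2 correct
```

Exact matching is strict: an answer of "United States" instead of "USA" counts as a miss, so a real comparison may want some normalization before scoring.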

lapp0 commented 3 months ago

> That's why I found that outlines achieves a relatively low performance for string types compared to other libraries, and I'm sure that this is not solely a problem of the model.

That's very interesting. Could you link the libraries which perform best in your experiments? Or is it simply [\w ] as you described?

liqul commented 3 months ago

You can take a look at this one https://github.com/noamgat/lm-format-enforcer

lapp0 commented 3 months ago

Looking at lm-format-enforcer, it seems they allow any token other than a quote to be produced. I'm not sure what makes outlines perform worse, but experimenting with better string patterns and pydantic.constr is definitely worth doing.

lapp0 commented 3 months ago

If we require every string to start with an alphanumeric character (but don't constrain it further after that), the output is much better.

_any_alphanum = r'[^\W_]'
_any_string_inner = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
smart_string = f"({_any_alphanum}{_any_string_inner}*)?"

Full pattern: '([^\\W_]([^"\\\\\\x00-\\x1F\\x7F-\\x9F]|\\\\["\\\\])*)?'
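What this pattern allows can be checked with Python's `re` (an approximation; outlines builds its FSM from the pattern with its own regex tooling, and `allowed` is an illustrative helper). In Python 3, `[^\W_]` means "a word character other than underscore", i.e. a letter or digit. Note the trailing `?` still permits an empty string, and only the first character is restricted.

```python
import re

# Patterns from the comment above.
_any_alphanum = r'[^\W_]'  # a word character that is not "_": a letter or digit
_any_string_inner = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
smart_string = f"({_any_alphanum}{_any_string_inner}*)?"


def allowed(value: str) -> bool:
    """True if the string contents (without surrounding quotes) match smart_string."""
    return re.fullmatch(smart_string, value) is not None


for value in ["Los Angeles International Airport", "America/Los_Angeles", "", ", ", ": ", "_x"]:
    print(f"{value!r:38} {allowed(value)}")
```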

e.g. {'name': {'type': 'string', 'description': 'The name of the airport.', 'pattern': '([^\\W_]([^"\\\\\\x00-\\x1F\\x7F-\\x9F]|\\\\["\\\\])*)?'}}

Output:

{
  "name": "Los Angeles International Airport",
  "IATA": "LAX",
  "ICAO": "KLAX",
  "location": {
    "city": "Los Angeles",
    "country": "United States",
    "coordinates": {
      "latitude": 33.9416,
      "longitude": -118.4085
    }
  },
  "timezone": "America/Los_Angeles"
}
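Rather than editing each property by hand, the same pattern could be attached to every string property that lacks one. The function below is a hypothetical helper, not part of outlines: it walks a JSON schema dict, leaves existing patterns (such as the IATA/ICAO ones) untouched, and returns a modified copy that could then be passed to `generate.json` via `json.dumps`. It is a rough sketch and does not special-case schema keywords like `default` or `examples`.

```python
import copy
from typing import Any, Dict

# The "smart string" pattern from the comment above.
SMART_STRING = r'([^\W_]([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])*)?'


def add_default_string_pattern(schema: Dict[str, Any], pattern: str = SMART_STRING) -> Dict[str, Any]:
    """Return a copy of `schema` where every string-typed subschema without
    its own "pattern" gets `pattern` (hypothetical helper)."""
    schema = copy.deepcopy(schema)  # don't mutate the caller's schema

    def visit(node: Any) -> None:
        if isinstance(node, dict):
            if node.get("type") == "string" and "pattern" not in node:
                node["pattern"] = pattern
            for value in node.values():
                visit(value)
        elif isinstance(node, list):
            for item in node:
                visit(item)

    visit(schema)
    return schema


# Usage on a trimmed-down version of the Airport schema from the issue.
airport_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "IATA": {"type": "string", "pattern": "^[A-Z]{3}$"},
        "location": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}
patched = add_default_string_pattern(airport_schema)
# e.g. generator = generate.json(model, json.dumps(patched), sampler=sampler)
```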

You will likely also get better results if you apply a chat template, per https://github.com/outlines-dev/outlines/issues/987
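A minimal sketch of applying the chat template yourself, assuming a Hugging Face `transformers` tokenizer for the same model. `to_chat_prompt` is a hypothetical helper; `apply_chat_template` with `tokenize=False` and `add_generation_prompt=True` is the standard transformers call for producing a formatted prompt string.

```python
from typing import Any


def to_chat_prompt(tokenizer: Any, user_message: str) -> str:
    """Wrap a plain prompt in the model's chat template.

    `tokenizer` is assumed to be a transformers tokenizer, e.g.
    AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct").
    """
    messages = [{"role": "user", "content": user_message}]
    # tokenize=False returns a string; add_generation_prompt=True appends the
    # header that cues the assistant's reply.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


# Usage with the reproduction script from the issue (requires model access):
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# print(run_outlines(to_chat_prompt(tokenizer, prompt), schema, loaded))
```

One thing worth checking: the Llama 3 template already begins with the `<|begin_of_text|>` token, so if the resulting string is tokenized again with special tokens added, the BOS token may end up duplicated.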

leloykun commented 3 months ago

@liqul Instruction-finetuned models tend to be annoyingly template-dependent, and the more they are finetuned, the worse the problem gets.

IMO it would also be interesting to measure, through ablation benchmarks, how much applying or not applying the chat template affects model performance.

liqul commented 3 months ago

@lapp0 Cool, I didn't realize the chat template isn't applied by default in outlines. I believe this can improve generation quality, though I haven't tried it yet.

lapp0 commented 3 months ago

> @lapp0 Cool, I didn't realize the chat template isn't applied by default in outlines. I believe this can improve generation quality, though I haven't tried it yet.

It isn't applied by default yet; the issue still hasn't been implemented. Based on my observations, though, it almost always improves quality and should be applied :)