dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

JSON fails but regex does not #1232

Open cpfiffer opened 3 days ago

cpfiffer commented 3 days ago

Describe the issue as clearly as possible:

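When using Outlines with the Llama 3.2 Vision model, generation constrained by a simple regex pattern works. Generation constrained by a regex built from a JSON schema fails with an index-out-of-bounds error: an `IndexError` on CPU and a device-side assert on GPU.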

Steps/code to reproduce the bug:

```python
from io import BytesIO
import json
from urllib.request import urlopen
from PIL import Image
from pydantic import BaseModel
import outlines
from transformers import MllamaForConditionalGeneration, AutoProcessor
from transformers import AutoTokenizer
from outlines.fsm.json_schema import build_regex_from_schema

def img_from_url(url):
    img_byte_stream = BytesIO(urlopen(url).read())
    return Image.open(img_byte_stream).convert("RGB")

image_url = "https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg"
image = img_from_url(image_url)

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

llm = outlines.models.TransformersVision(model, tokenizer, processor)

# Works: regex that only admits a fixed set of planet names
pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto"
planet_generator = outlines.generate.regex(llm, pattern)
out = planet_generator(
    "What: <|image|>",
    [image]
)
print(out)
# Fails: regex built from a Pydantic model's JSON schema
class ImageData(BaseModel):
    caption: str

im_reg = build_regex_from_schema(json.dumps(ImageData.model_json_schema()))
print(im_reg)
generator = outlines.generate.regex(llm, im_reg)

out = generator(  # vision.py line 44, where both tracebacks below start
    "What: <|image|>",
    [image]
)
print(out)
```
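The two constraints differ sharply in how much they admit. The sketch below checks both patterns with Python's `re` module: the planet pattern is copied from the script, and the caption pattern is copied from the `print(im_reg)` output in the logs. This is only an illustration. Outlines compiles these patterns into a token-level automaton and does not call `re`, and the helper name `accepts` is ours.

```python
import re

# Pattern from the "Works" half of the script.
PLANET_PATTERN = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto"
# Pattern printed by build_regex_from_schema for ImageData (copied from the log).
CAPTION_PATTERN = r'\{[ ]?"caption"[ ]?:[ ]?"([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])*"[ ]?\}'


def accepts(pattern: str, text: str) -> bool:
    """True if the whole of `text` matches `pattern`."""
    return re.fullmatch(pattern, text) is not None


# The planet constraint admits exactly nine strings and nothing else.
print(accepts(PLANET_PATTERN, "Pluto"))   # True
print(accepts(PLANET_PATTERN, "Pluto "))  # False
# The caption constraint admits almost any text between the quotes,
# including text that looks like a special token.
print(accepts(CAPTION_PATTERN, '{"caption": "Buzz Aldrin on the Moon"}'))  # True
print(accepts(CAPTION_PATTERN, '{"caption": "<|image|>"}'))                # True
print(accepts(CAPTION_PATTERN, '{"caption": 42}'))                         # False
```

Because the caption's character class excludes only quotes, backslashes and control characters, nearly every token in the vocabulary can extend a partial match. The planet regex allows only tokens that spell out a planet name.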

Expected result:

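The first generator returns a planet name (Mercury, Venus, Earth, etc.), and the second returns a JSON object matching the schema, e.g. `{"caption": "..."}`.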

Error message:

On CPU:

```
> python vision.py                                              ~/dottxt/outlines/debug
Loading checkpoint shards: 100%|██████████████████████████| 5/5 [00:42<00:00,  8.41s/it]
Pluto
\{[ ]?"caption"[ ]?:[ ]?"([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])*"[ ]?\}
Traceback (most recent call last):
  File "/home/cameron/dottxt/outlines/debug/vision.py", line 44, in <module>
    out = generator(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/generate/api.py", line 556, in __call__
    completions = self.model.generate(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/models/transformers_vision.py", line 56, in generate
    generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/models/transformers.py", line 350, in _generate_output_seq
    output_ids = self.model.generate(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/transformers/generation/utils.py", line 3223, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 104, in __call__
    scores = processor(input_ids, scores)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/processors/base_logits_processor.py", line 78, in __call__
    processed_logits = self.process_logits(input_ids, torch_logits)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/processors/structured.py", line 121, in process_logits
    mask[batch_indices_concat, allowed_tokens_concat] = False
IndexError: index 128256 is out of bounds for dimension 1 with size 128256
```
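The `IndexError` says the mask has 128256 columns (valid ids 0 to 128255), but the set of allowed tokens contains id 128256.

We have not confirmed the cause against the Outlines source. One plausible explanation is that the `AutoTokenizer` for this checkpoint holds one more token than the language-model head has outputs. We believe that extra id 128256 is `<|image|>`. The planet regex never allows that token, but the caption's `[^"...]` class does. A quick check on the real objects, assuming the Mllama config layout in transformers 4.46, is to compare `len(tokenizer)` with `model.config.text_config.vocab_size`.

The toy sketch below reproduces the failure mode in pure Python on a 6-token vocabulary with only 5 logit columns. It also shows a clipping workaround. All sizes and function names are hypothetical.

```python
# Toy model of the failing step: one row of a "disallowed" mask over the logits.
# Sizes are hypothetical stand-ins: a 6-token tokenizer whose last id (5) plays
# the role of the extra <|image|> token, and a head producing only 5 logits.
TOKENIZER_SIZE = 6
LOGITS_WIDTH = 5

PLANET_ALLOWED = [1, 2]                        # narrow constraint: never hits id 5
CAPTION_ALLOWED = list(range(TOKENIZER_SIZE))  # broad constraint: admits every id


def build_mask(allowed_token_ids, vocab_width):
    """Return True for masked-out positions and False for allowed ones.

    Like `mask[batch_indices, allowed_tokens] = False` in the traceback, this
    raises IndexError when an allowed id has no matching logit column.
    """
    mask = [True] * vocab_width
    for token_id in allowed_token_ids:
        mask[token_id] = False
    return mask


def build_mask_clipped(allowed_token_ids, vocab_width):
    """Hypothetical workaround: skip allowed ids the logits have no column for."""
    mask = [True] * vocab_width
    for token_id in allowed_token_ids:
        if 0 <= token_id < vocab_width:
            mask[token_id] = False
    return mask


print(build_mask(PLANET_ALLOWED, LOGITS_WIDTH))  # [True, False, False, True, True]
try:
    build_mask(CAPTION_ALLOWED, LOGITS_WIDTH)
except IndexError as err:
    print("IndexError:", err)  # same failure mode as the CPU traceback
print(build_mask_clipped(CAPTION_ALLOWED, LOGITS_WIDTH))  # [False, False, False, False, False]
```

Clipping only hides the symptom. Excluding such tokens when the index is built would be an alternative.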

On GPU:

```
> python vision.py                                              ~/dottxt/outlines/debug
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████████████████████| 5/5 [00:12<00:00,  2.45s/it]
Saturn
\{[ ]?"caption"[ ]?:[ ]?"([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])*"[ ]?\}
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [169,0,0], thread: [12,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Traceback (most recent call last):
  File "/home/cameron/dottxt/outlines/debug/vision.py", line 44, in <module>
    out = generator(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/generate/api.py", line 556, in __call__
    completions = self.model.generate(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/models/transformers_vision.py", line 56, in generate
    generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/outlines/models/transformers.py", line 350, in _generate_output_seq
    output_ids = self.model.generate(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
  File "/home/cameron/dottxt/outlines/debug/.pyron/lib/python3.10/site-packages/transformers/generation/utils.py", line 3249, in _sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
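On GPU, the same out-of-range index most likely surfaces as the `index out of bounds` assert in `IndexKernel.cu`. As the message itself warns, the Python traceback may then point at the wrong call: here it names `torch.multinomial` rather than the masking line.

Following the log's own suggestion, the minimal sketch below sets `CUDA_LAUNCH_BLOCKING=1` from inside the script. The helper name `enable_cuda_launch_blocking` is hypothetical. The variable must be set before CUDA is initialized. Running `CUDA_LAUNCH_BLOCKING=1 python vision.py` from the shell does the same thing.

```python
import os
import sys
import warnings


def enable_cuda_launch_blocking() -> bool:
    """Ask CUDA to run kernels synchronously so a device-side assert is
    reported at the call that launched the failing kernel.

    Returns False (with a warning) if torch was already imported, in which
    case CUDA may already be initialized and the setting may be ignored.
    """
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    if "torch" in sys.modules:
        warnings.warn("torch was imported before CUDA_LAUNCH_BLOCKING was set")
        return False
    return True


# Call at the very top of vision.py, before importing outlines or transformers.
enable_cuda_launch_blocking()
```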

Outlines/Python version information:


```
accelerate==1.0.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
airportsdata==20241001
annotated-types==0.7.0
async-timeout==4.0.3
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.4.0
cloudpickle==3.1.0
datasets==3.0.2
dill==0.3.8
diskcache==5.6.3
filelock==3.16.1
frozenlist==1.5.0
fsspec==2024.9.0
huggingface-hub==0.26.1
idna==3.10
interegular==0.3.3
Jinja2==3.1.4
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
lark==1.2.2
MarkupSafe==3.0.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.4.2
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
outlines==0.1.1
outlines_core==0.1.14
packaging==24.1
pandas==2.2.3
pillow==11.0.0
propcache==0.2.0
psutil==6.1.0
pyarrow==17.0.0
pycountry==24.6.1
pydantic==2.9.2
pydantic_core==2.23.4
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rpds-py==0.20.0
safetensors==0.4.5
six==1.16.0
sympy==1.13.3
tokenizers==0.20.1
torch==2.4.0
tqdm==4.66.5
transformers==4.46.0
triton==3.0.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
xxhash==3.5.0
yarl==1.16.0
```
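The list above is a full environment dump. Sometimes only a few packages matter; here that is probably `outlines`, `outlines_core`, `transformers` and `torch`. In that case, a short script like the one below can collect just those versions together with the Python version. It uses only the standard library's `importlib.metadata`. The package selection and the helper name `collect_versions` are our own choices, not an Outlines requirement.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Packages most relevant to this report (our selection, not an official list).
PACKAGES = ["outlines", "outlines_core", "transformers", "torch", "tokenizers", "pydantic"]


def collect_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print(f"python=={platform.python_version()}")
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name}=={ver}" if ver else f"{name} (not installed)")
```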

Context for the issue:

From HatterNoMad on the Outlines Discord.

rlouf commented 3 days ago

<Please write a descriptive title> 😄

cpfiffer commented 3 days ago

Woops, issue name updated.