huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
134.63k stars · 26.92k forks

`stopping_criteria` not working with llama #22436

Closed mk-cupist closed 1 year ago

mk-cupist commented 1 year ago

System Info

I am generating text from the llama-13b model, but it continues generating even though the stopping criteria were met. The same stopping criteria work fine with other models such as GPT-J 6B.

I loaded llama-13b with `model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', load_in_8bit=True)`, and my stopping criteria list looks like the one below:

```python
stopping_criteria_list = transformers.StoppingCriteriaList([
    _SentinelTokenStoppingCriteria(
        sentinel_token_ids=tokenizer(
            "\n",
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to("cuda"),
        starting_idx=tokenized_items.input_ids.shape[-1])
])
```

Thank you.

Who can help?

No response

Information

Tasks

Reproduction

  1. Load the llama model: `model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', load_in_8bit=True)`
  2. Build the stopping criteria:

```python
stopping_criteria_list = transformers.StoppingCriteriaList([
    _SentinelTokenStoppingCriteria(
        sentinel_token_ids=tokenizer(
            "\n",
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to("cuda"),
        starting_idx=tokenized_items.input_ids.shape[-1])
])

...

class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue

            for window in trimmed_sample.unfold(
                    0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False
```
  3. Generate:
```python
model_output = model.generate(stopping_criteria=stopping_criteria_list,
                              **tokenized_items, **generation_settings,
                              pad_token_id=tokenizer.eos_token_id)
```

Expected behavior

Generation should stop as soon as the model produces `\n`.
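
A quick sanity check (a debugging sketch, assuming the `tokenizer`, `tokenized_items`, and `model_output` names from the snippets above): print the sentinel ids next to the freshly generated ids; if the sentinel contains an id that never appears in the generated tail, the criteria can never fire.

```python
# Debugging sketch: assumes tokenizer, tokenized_items, and model_output
# from the reproduction steps above.
sentinel_ids = tokenizer("\n", add_special_tokens=False, return_tensors="pt").input_ids
print("sentinel ids:  ", sentinel_ids)
print("generated tail:", model_output[0, tokenized_items.input_ids.shape[-1]:])
```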

sgugger commented 1 year ago

cc @gante Note that this might require #22402 as the Llama tokenizer has a few bugs we are fixing.

gante commented 1 year ago

@mk-cupist 👋

Let's see if the PR above fixes it. If it doesn't... we need to find a way to reproduce the issue with publicly available weights, otherwise it will be hell for me to figure out what's going on 😅

mk-cupist commented 1 year ago

@gante I tried the PR with decapoda-research/llama-13b-hf and changed the tokenizer_config to LlamaTokenizer, but it still does not work.

sgugger commented 1 year ago

That repo is based on an intermediate state of the PR made to Transformers. It cannot even work with the main branch.

mk-cupist commented 1 year ago

Is there any llama-13b checkpoint that you know would be worth trying?

mk-cupist commented 1 year ago

I tried swype/deepshard-13B-raw, which uses 4.28.0.dev0, but it doesn't work either.

gante commented 1 year ago

@mk-cupist let's wait for the resolution of #22402 :) Your issue depends on the use of the tokenizer, so it may be related

mk-cupist commented 1 year ago

> @mk-cupist let's wait for the resolution of #22402 :) Your issue depends on the use of the tokenizer, so it may be related

Thank you!

Michanne commented 1 year ago

Thanks for this, it is also needed to get LLaMa performing correctly with Langchain chains.

mk-cupist commented 1 year ago

I re-converted with the conversion code from the PR, but I still have the same issue.

oobabooga commented 1 year ago

I can reproduce the issue. Here is some additional code for testing:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('models/llama-7b/')

>>> tokenizer.encode('\nYou:', add_special_tokens=False)
[29871, 13, 3492, 29901]

>>> tokenizer.decode([29871, 13, 3492, 29901])
' \nYou:'

>>> tokenizer.decode([13, 3492, 29901])
' \nYou:'
```

There is always an extra space (29871) everywhere. Also,

```python
>>> tokenizer.encode(' ', add_special_tokens=False)
[259]

>>> tokenizer.decode([259])
'  ' # two spaces

>>> tokenizer.decode([29871])
' ' # one space
```

If you encode a space, it becomes id 259 instead of 29871. And if you decode [259], the result is two spaces.

Very confusing behavior overall.
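
The leading 29871 is the SentencePiece prefix-space token, which the tokenizer prepends when encoding standalone text; a generated continuation typically emits 13 (`\n`) without it, so a sentinel built from `tokenizer("\n")` can never match the generated tokens. A minimal workaround sketch (assuming 29871 really is the prefix token for the tokenizer in use; verify before relying on it):

```python
# Sketch: strip the SentencePiece prefix-space token from an encoded sentinel
# so it matches token ids as they appear mid-generation. The id 29871 is an
# assumption specific to this Llama tokenizer -- check it first.
sentinel_ids = tokenizer("\n", add_special_tokens=False, return_tensors="pt").input_ids
if sentinel_ids[0, 0].item() == 29871:
    sentinel_ids = sentinel_ids[:, 1:]  # keep only the newline token (13)
```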

sgugger commented 1 year ago

@oobabooga Those issues will be fixed by #22402

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

poohzaza166 commented 1 year ago

Hi, I experienced the same problem. I installed transformers from git using the main branch, and the model seems to ignore the stop params completely.

gante commented 1 year ago

@poohzaza166 would you be able to share a stand-alone reproduction script? :)

poohzaza166 commented 1 year ago

https://github.com/poohzaza166/daiagnosing-llama

gante commented 1 year ago

@poohzaza166 that is not a short reproducible script :) I can only give a hand if you help me pin down the issue with a short reproducer

poohzaza166 commented 1 year ago

```python
from transformers import LlamaTokenizer, LlamaForCausalLM
import transformers
import torch
import random

seeds = 56416
preprompt = '''Utachi's Persona: Meet utachi, a half-British, half-Japanese playful kind anime girl who loves sukiyaki and reading novels. On the side, she does a bit of programming and is a curious person who often does some odd things.
Scenario: i am hanging out on discord and someone message me
Utachi: Hi there! My name is utachi and I'm so excited to meet you all! I love exploring new things and trying out new hobbies. Do you have any recommendations for what I should try next?
pooh: Hi nice to meet you! i am pooh nice to meet you. Are you interesting in watching anime? i am watching this new show call Bochi the rock it basiclly k-on but for people with social anxiety.
Utachi: i see.'''

model = 'Neko-Institute-of-Science/pygmalion-7b'

tokenizer = LlamaTokenizer.from_pretrained(model)
model = LlamaForCausalLM.from_pretrained(model, low_cpu_mem_usage=True, load_in_8bit=True, device_map='auto',early_stopping=True,)

class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue

            for window in trimmed_sample.unfold(
                    0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False

tokenized = tokenizer(preprompt, return_tensors="pt").to('cuda')

stopping_criteria_list = transformers.StoppingCriteriaList([
    _SentinelTokenStoppingCriteria(
        sentinel_token_ids=tokenizer(
            "pooh:",
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to("cuda"),
        starting_idx=tokenized.input_ids.shape[-1])
])

random.seed(seeds)
torch.manual_seed(seeds)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seeds)

token = model.generate(**tokenized,
                        stopping_criteria=stopping_criteria_list,
                        do_sample=True,
                        max_new_tokens = 250, temperature=0.7, top_p=0.9, top_k = 0,typical_p = 1.0, repetition_penalty = 1.05, 
                        early_stopping=True)

output = tokenizer.decode(token[0], skip_special_tokens=True)

print(output)
```

```
Utachi's Persona: Meet utachi, a half-British, half-Japanese playful kind anime girl who loves sukiyaki and reading novels. On the side, she does a bit of programming and is a curious person who often does some odd things.
Scenario: i am hanging out on discord and someone message me
Utachi: Hi there! My name is utachi and I'm so excited to meet you all! I love exploring new things and trying out new hobbies. Do you have any recommendations for what I should try next?
pooh: Hi nice to meet you! i am pooh nice to meet you. Are you interesting in watching anime? i am watching this new show call Bochi the rock it basiclly k-on but for people with social anxiety.
Utachi: i see. i don't know much about anime but i've heard good things about it. i like to watch shows with interesting characters and plots. what kind of anime do you recommend?
pooh: well if you like comedy i would recommend girlish number. it is a very cute show and has a lot of comedic scenes. also if you like romance i would recommend Kimi no Na wa or Your Name. if you like action i would recommend Attack on Titan or Death Note.
Utachi: thank you so much! i will check those out and let you know my thoughts!
```

If the code were working, it should not have generated the `pooh:` token.

poohzaza166 commented 1 year ago

> @poohzaza166 that is not a short reproducible script :) I can only give a hand if you help me pin down the issue with a short reproducer

Sorry about that. I am using a Miniconda virtual env; here is my package list:


```
Package                  Version      Editable project location
------------------------ ------------ --------------------------------------------
absl-py                  1.4.0
accelerate               0.18.0
aiofiles                 23.1.0
aiohttp                  3.8.4
aiosignal                1.3.1
altair                   4.2.2
anyio                    3.6.2
asttokens                2.2.1
async-timeout            4.0.2
attrs                    23.1.0
backcall                 0.2.0
bitsandbytes             0.38.1
Bottleneck               1.3.5
Brotli                   1.0.9
brotlipy                 0.7.0
cachetools               5.3.0
cairocffi                1.4.0
CairoSVG                 2.5.2
cchardet                 2.1.7
certifi                  2022.12.7
cffi                     1.15.0
chardet                  5.1.0
charset-normalizer       3.1.0
chess                    1.9.4
chess-gym                0.0.5
click                    8.1.3
cloudpickle              2.2.1
cmake                    3.26.3
colorama                 0.4.6
comm                     0.1.2
contourpy                1.0.7
cryptography             3.4.8
cssselect2               0.7.0
cycler                   0.11.0
dataclasses-json         0.5.7
datasets                 2.11.0
debugpy                  1.6.6
decorator                5.1.1
defusedxml               0.7.1
dill                     0.3.6
entrypoints              0.4
et-xmlfile               1.1.0
executing                1.2.0
fastapi                  0.95.1
ffmpy                    0.3.0
filelock                 3.12.0
filetype                 1.2.0
flexgen                  0.1.7
fonttools                4.39.3
frozenlist               1.3.3
fsspec                   2023.4.0
google-auth              2.16.1
google-auth-oauthlib     0.4.6
gptcache                 0.1.21
gradio                   3.25.0
gradio_client            0.1.4
greenlet                 2.0.2
grpcio                   1.51.3
gym                      0.26.2
gym-notices              0.0.8
h11                      0.14.0
httpcore                 0.17.0
httpx                    0.24.0
huggingface-hub          0.14.1
idna                     3.4
importlib-metadata       6.0.0
inputs                   0.5
ipykernel                6.21.2
ipython                  8.11.0
jedi                     0.18.2
Jinja2                   3.1.2
joblib                   1.1.1
jsonschema               4.17.3
jupyter_client           8.0.3
jupyter_core             5.2.0
kiwisolver               1.4.4
langchain                0.0.155
linkify-it-py            2.0.0
lit                      16.0.2
llama-cpp-python         0.1.36
Markdown                 3.4.1
markdown-it-py           2.2.0
MarkupSafe               2.1.2
marshmallow              3.19.0
marshmallow-enum         1.5.1
matplotlib               3.7.1
matplotlib-inline        0.1.6
mdit-py-plugins          0.3.3
mdurl                    0.1.2
mkl-fft                  1.3.1
mkl-random               1.2.2
mkl-service              2.4.0
mpmath                   1.3.0
multidict                6.0.4
multiprocess             0.70.14
mutagen                  1.46.0
mypy-extensions          1.0.0
nest-asyncio             1.5.6
networkx                 3.1
numexpr                  2.8.4
numpy                    1.24.3
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
oauthlib                 3.2.2
openai                   0.23.1
openapi-schema-pydantic  1.2.4
openpyxl                 3.0.9
orjson                   3.8.11
packaging                23.1
pandas                   2.0.1
pandas-stubs             2.0.0.230412
parso                    0.8.3
peft                     0.3.0.dev0
pexpect                  4.8.0
pickleshare              0.7.5
Pillow                   9.5.0
pip                      23.0.1
platformdirs             3.1.0
prompt-toolkit           3.0.38
protobuf                 4.22.0
psutil                   5.9.5
ptyprocess               0.7.0
PuLP                     2.7.0
pure-eval                0.2.2
pyarrow                  11.0.0
pyasn1                   0.4.8
pyasn1-modules           0.2.8
pycparser                2.21
pydantic                 1.10.7
pydub                    0.25.1
Pygments                 2.14.0
pynput                   1.7.6
pyOpenSSL                20.0.1
pyparsing                3.0.9
pyrsistent               0.19.3
PySide2                  5.15.2.1
PySide6-Essentials       6.4.1
PySocks                  1.7.1
python-chess             1.999
python-dateutil          2.8.2
python-multipart         0.0.6
python-xlib              0.31
pytz                     2023.3
PyYAML                   6.0
pyzmq                    25.0.0
regex                    2023.3.23
requests                 2.29.0
requests-oauthlib        1.3.1
responses                0.18.0
rsa                      4.9
rwkv                     0.7.3
sacremoses               0.0.43
safetensors              0.3.0
semantic-version         2.10.0
sentencepiece            0.1.98
setuptools               66.0.0
shiboken2                5.15.2.1
shiboken6                6.4.1
six                      1.16.0
sniffio                  1.3.0
SQLAlchemy               2.0.11
stack-data               0.6.2
starlette                0.26.1
streamdeck               0.9.3
streamdeck-ui            2.0.6
stringcase               1.2.0
sympy                    1.11.1
tenacity                 8.2.2
tensorboard              2.12.0
tensorboard-data-server  0.7.0
tensorboard-plugin-wit   1.8.1
tinycss2                 1.2.1
tokenizers               0.13.3
toolz                    0.12.0
torch                    2.0.0
tornado                  6.2
tqdm                     4.65.0
traitlets                5.9.0
transformers             4.29.0.dev0  /mnt/sharessd/code/python/QABOT/transformers
triton                   2.0.0
types-pytz               2023.3.0.0
typing_extensions        4.5.0
typing-inspect           0.8.0
tzdata                   2023.3
uc-micro-py              1.0.1
urllib3                  1.26.15
uvicorn                  0.22.0
wcwidth                  0.2.6
webencodings             0.5.1
websockets               10.4
Werkzeug                 2.2.3
wheel                    0.38.4
xxhash                   3.2.0
yarl                     1.9.2
yt-dlp                   2023.2.17
zipp                     3.11.0
```

gante commented 1 year ago

Hey @poohzaza166 👋

I had a look at your snippet, and the problem does not stem from the stopping criteria or the Llama model itself, but rather from how the tokenizer works. It also doesn't seem to be a bug. My recommendation would be to design the stopping criteria from the token ids, and not from raw text :)

See this example:

```python
from transformers import LlamaTokenizer
import transformers
import torch

tokenizer = LlamaTokenizer.from_pretrained('huggyllama/llama-7b')

class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    def __init__(self, sentinel_token_ids: torch.LongTensor, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue

            for window in trimmed_sample.unfold(0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False

sentinel_token_ids = tokenizer("pooh:", add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
print(sentinel_token_ids)
stopping_criteria_list = transformers.StoppingCriteriaList([
    _SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=0)
])

test_input_1 = """This is a test.\npooh: potato."""
test_input_ids = tokenizer(test_input_1, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
print(stopping_criteria_list(test_input_ids, None))

test_input_2 = """This is a test. pooh: potato."""
test_input_ids = tokenizer(test_input_2, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
print(stopping_criteria_list(test_input_ids, None))
```
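
The example above illustrates the pitfall: `"pooh:"` encoded on its own yields different ids than the same text appearing mid-sequence, so a sentinel derived from raw text may never match. One possible workaround (a sketch, not an official transformers API) is to encode the stop text behind a throwaway prefix and slice the prefix's ids off:

```python
# Sketch: derive sentinel ids as they would appear mid-sequence. The "\n"
# prefix is an arbitrary, hypothetical choice; SentencePiece can still merge
# tokens differently at other boundaries, so verify the result for your case.
prefix_ids = tokenizer("\n", add_special_tokens=False).input_ids
full_ids = tokenizer("\npooh:", add_special_tokens=False).input_ids
sentinel_ids = full_ids[len(prefix_ids):]  # ids for "pooh:" after a newline
```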
poohzaza166 commented 1 year ago

@gante Hi, thanks for the help. Though now I have a problem: if I pin the stop condition to specific token ids, there are sometimes multiple token sequences that produce the same plaintext stop word. Is there a way to get around this?

My original idea was to stream the generation, append the words to a string, and use a regex to halt the loop when it detects the stop token in plain text. This seems janky, though; is there a "proper" way to do this?

gante commented 1 year ago

@poohzaza166 we do not have a solution for that problem, but as always you can design a custom stopping criteria -- nothing prevents you from expanding the code you shared to check against multiple stop sequences :D

(and yes, it is better to do it at the token level, otherwise you need to pass the tokens back to the CPU and decode them, which will slow generation down significantly)
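
A minimal sketch of such a multi-sequence criteria, extending the `_SentinelTokenStoppingCriteria` shared earlier in the thread (untested, and not an official transformers API):

```python
import torch
import transformers

class _MultiSentinelStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation when any of several sentinel token sequences appears."""

    def __init__(self, sentinel_token_ids_list, starting_idx: int):
        super().__init__()
        self.sentinels = sentinel_token_ids_list  # list of 1-D LongTensors
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed = sample[self.starting_idx:]
            for sentinel in self.sentinels:
                if trimmed.shape[-1] < sentinel.shape[-1]:
                    continue
                # Slide a window of the sentinel's length over the new tokens.
                for window in trimmed.unfold(0, sentinel.shape[-1], 1):
                    if torch.all(torch.eq(sentinel, window)):
                        return True
        return False

# Example usage; the ids below are illustrative placeholders, not real Llama ids.
sentinels = [torch.tensor([13]), torch.tensor([282, 1148, 29901])]
criteria = transformers.StoppingCriteriaList(
    [_MultiSentinelStoppingCriteria(sentinels, starting_idx=0)]
)
```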

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.