skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

LLM caching breaks with shared-prefix labels under certain conditions #1047

Closed: ottonemo closed this issue 7 months ago

ottonemo commented 7 months ago

When labels share a prefix (e.g., '1' and '12'), caching breaks if the tokenizer does not append an EOS token to the tokenized strings.

Example to reproduce:

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from skorch.llm import ZeroShotClassifier

model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
model = model.to('cuda')

clf = ZeroShotClassifier(
    model=model,
    tokenizer=tokenizer,
    device='cuda',
    use_caching=True,
)

X = np.array([["Hey there"], ["No thank you"], ["Whatever"]])
y = ['0', '1', '12']

clf.fit(X, y)
clf.predict(["Hey there"]) # works
clf.predict(["Hey there"]) # throws IndexError

Exception:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
----> 1 clf.predict(["Hey there"])

File ~.../site-packages/skorch/llm/classifier.py:537, in _LlmBase.predict(self, X)
    507 """Return the classes predicted by the LLM.
    508 
    509 Predictions will be forced to be one of the labels the model learned
   (...)
    534 
    535 """
    536 # y_proba not normalized but it's not neeeded here
--> 537 y_proba = self._predict_proba(X)
    538 pred_ids = y_proba.argmax(1)
    539 y_pred = self.classes_[pred_ids]

File ~.../site-packages/skorch/llm/classifier.py:451, in _LlmBase._predict_proba(self, X)
    449 for xi in X:
    450     text = self.get_prompt(xi)
--> 451     proba = self._predict_one(text)
    452     y_proba.append(proba)
    453 y_proba = np.vstack(y_proba)

File ~.../site-packages/skorch/llm/classifier.py:407, in _LlmBase._predict_one(self, text)
    405 probas_all_labels = []
    406 for label_id in self.label_ids_:
--> 407     logits = self.cached_model_.generate_logits(label_id=label_id, **inputs)
    408     logits = torch.vstack(logits)
    409     probas = torch.nn.functional.softmax(logits, dim=-1)

File ~.../site-packages/skorch/llm/classifier.py:246, in _CacheModelWrapper.generate_logits(self, label_id, **kwargs)
    244 logits_cached = self.get_cache(kwargs)
    245 while logits_cached is not None:
--> 246     if label_id[0] == self.tokenizer.eos_token_id:
    247         # don't extend with eos_token -- it is already there at the end,
    248         # we don't need it twice
    249         break
    251     recorded_logits.append(logits_cached)

IndexError: list index out of range
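
For context, here is a simplified, standalone model of the failing loop (not skorch code; it assumes the loop consumes one cached label token per iteration, and the token ids are hypothetical): with labels '1' and '12' and no trailing EOS, the shorter label's ids run out and label_id[0] raises the same IndexError as above.

label_id = [28740]   # hypothetical token ids for the label '1'
eos_token_id = 2     # </s> for this tokenizer

try:
    while True:                            # stands in for "while logits_cached is not None"
        if label_id[0] == eos_token_id:    # raises IndexError once label_id is empty
            break
        label_id = label_id[1:]            # cached prefix consumed; keep going
except IndexError as exc:
    print("reproduced:", exc)              # list index out of range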

Tokenizer info:

LlamaTokenizerFast(name_or_path='HuggingFaceH4/zephyr-7b-beta', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='left', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>', 'additional_special_tokens': ['<unk>', '<s>', '</s>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
    0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
    2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

Not shown here: tokenizer.add_eos_token is False in this case, which is the cause of this issue and apparently the default for LlamaTokenizer*.
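
A quick way to see the ambiguity (a standalone sketch, not skorch code; the printed ids depend on the tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
print(tokenizer.add_eos_token)  # False; no trailing </s> on tokenized labels

# Without an EOS token, the ids of '1' are a strict prefix of the ids of '12',
# so the cache has no marker telling it where the shorter label ends.
ids_1 = tokenizer("1", add_special_tokens=False).input_ids
ids_12 = tokenizer("12", add_special_tokens=False).input_ids
print(ids_1, ids_12, ids_12[:len(ids_1)] == ids_1)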

The above sample can be fixed by manually appending the EOS token to the label ids:

clf.fit(X, y)
clf.label_ids_ = [n + [tokenizer.eos_token_id] for n in clf.label_ids_]
...

Update: I think the issue also requires that the model does not output an EOS token.

ottonemo commented 7 months ago

I failed to reproduce this with GPT-2 in a unit test. I haven't tried using Zephyr-7B beta in the unit test.

I won't be able to work on this any longer today, and I don't know when I'll be able to pick it up again.

Here's the unit test:

    @pytest.fixture(scope='class')
    def model_cacheable(self):
        from transformers import AutoModelForCausalLM
        return AutoModelForCausalLM.from_pretrained('gpt2')

    def test_caching_works_shared_label_prefix_without_eos(self, model_cacheable, classifier_cls):
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_cacheable.config._name_or_path)
        tokenizer.add_eos_token = False

        generate_fn = model_cacheable.generate

        def patched_generate(*args, **kwargs):
            # simulate a model whose output does not end in an EOS token by
            # stripping the last generated token
            original_result = generate_fn(*args, **kwargs)
            no_eos_result = original_result[:, :-1]
            return no_eos_result

        with patch.object(model_cacheable, 'generate', new=patched_generate):
            clf = classifier_cls(model=model_cacheable, tokenizer=tokenizer)

            X = np.array([["Hey there"], ["No thank you"], ["Whatever"]])
            y = ['0', '1', '12']

            clf.fit(X, y)

            for label_ids in clf.label_ids_:
                assert label_ids[-1] != model_cacheable.config.eos_token_id

            clf.predict(X)
            clf.predict(X)
ottonemo commented 7 months ago

Quick update. This is a test that reproduces the issue:

    def test_caching_works_shared_label_prefix_without_eos(self, model_cacheable, classifier_cls):
        clf = classifier_cls('HuggingFaceH4/zephyr-7b-beta')

        X = np.array([["Hey there"], ["No thank you"], ["Whatever"]])
        y = ['0', '1', '12']

        clf.fit(X, y)

        for label_ids in clf.label_ids_:
            assert label_ids[-1] != model_cacheable.config.eos_token_id

        clf.predict(X)
        clf.predict(X)

The issue is resolved by adding "not label_id or" to the condition in generate_logits: https://github.com/skorch-dev/skorch/blob/9ab3b2c079fbb9b74bde0e09283b17b288e8f9e7/skorch/llm/classifier.py#L246
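
A minimal, self-contained sketch of that guard, using a hypothetical helper should_stop_extending (illustrative only, not the actual patch):

def should_stop_extending(label_id, eos_token_id):
    # "not label_id" short-circuits, so label_id[0] is never evaluated on an
    # empty list; this is the guard that was missing above
    return not label_id or label_id[0] == eos_token_id

assert should_stop_extending([], 2)              # previously: IndexError
assert should_stop_extending([2], 2)             # label already ends in EOS
assert not should_stop_extending([123, 456], 2)  # keep extending from the cache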

I'm currently working on reproducing this issue with gpt2 so as not to blow up the model zoo we're using in the tests too much.

ottonemo commented 7 months ago

Resolved with #1048.