Closed: ottonemo closed this issue 9 months ago.
I failed to reproduce this with GPT2 and a unit test. I haven't tried using Zephyr-7B beta in the unit test.
I won't be able to work on this any longer today, and I don't know when I'll be able to pick it up again.
Here's the unit test:
# assumed module-level imports for this snippet
import numpy as np
import pytest
from unittest.mock import patch

@pytest.fixture(scope='class')
def model_cacheable(self):
    from transformers import AutoModelForCausalLM
    return AutoModelForCausalLM.from_pretrained('gpt2')

def test_caching_works_shared_label_prefix_without_eos(self, model_cacheable, classifier_cls):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_cacheable.config._name_or_path)
    tokenizer.add_eos_token = False

    generate_fn = model_cacheable.generate

    def patched_generate(*args, **kwargs):
        # strip the last token so the model output never ends in EOS
        original_result = generate_fn(*args, **kwargs)
        no_eos_result = original_result[:, :-1]
        return no_eos_result

    with patch.object(model_cacheable, 'generate', new=patched_generate):
        clf = classifier_cls(model=model_cacheable, tokenizer=tokenizer)
        X = np.array([["Hey there"], ["No thank you"], ["Whatever"]])
        y = ['0', '1', '12']
        clf.fit(X, y)

        for label_ids in clf.label_ids_:
            assert label_ids[-1] != model_cacheable.config.eos_token_id

        clf.predict(X)
        clf.predict(X)
Quick update. This is a test that reproduces the issue:
def test_caching_works_shared_label_prefix_without_eos(self, model_cacheable, classifier_cls):
    clf = classifier_cls('HuggingFaceH4/zephyr-7b-beta')
    X = np.array([["Hey there"], ["No thank you"], ["Whatever"]])
    y = ['0', '1', '12']
    clf.fit(X, y)

    for label_ids in clf.label_ids_:
        assert label_ids[-1] != model_cacheable.config.eos_token_id

    clf.predict(X)
    clf.predict(X)
The issue is resolved by adding "not label_id or" to the condition in generate_logits: https://github.com/skorch-dev/skorch/blob/9ab3b2c079fbb9b74bde0e09283b17b288e8f9e7/skorch/llm/classifier.py#L246
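For context, a rough sketch of the loop that condition sits in; this is paraphrased rather than copied from the linked line, so the names and structure are assumptions:

def cached_prefix_logits(label_id, eos_token_id, get_cache):
    # Paraphrased sketch of the cached-prefix loop in generate_logits.
    # The fix is the "not label_id or" guard: when the tokenizer does not
    # append an EOS token, label_id can be exhausted before an EOS id is
    # ever seen, so the loop has to stop on an empty list as well.
    recorded_logits = []
    logits_cached = get_cache(label_id)
    while logits_cached is not None:
        if not label_id or label_id[0] == eos_token_id:
            break
        recorded_logits.append(logits_cached)
        label_id = label_id[1:]
        logits_cached = get_cache(label_id)
    return recorded_logits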
I'm currently working on reproducing this issue with gpt2, so as not to blow up the model zoo we're using in the tests too much.
Resolved with #1048.
When adding labels with a shared prefix (e.g., '1' and '14'), caching breaks if the tokenizer does not add EOS tokens to the end of the tokenized strings.

Example to reproduce:
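The reproduction snippet did not survive the export; as a rough reconstruction based on the tests above (the use of ZeroShotClassifier, the exact labels, and where the failure surfaces are assumptions):

import numpy as np
from skorch.llm import ZeroShotClassifier

# a model whose tokenizer does not append an EOS token (LlamaTokenizer* default)
clf = ZeroShotClassifier('HuggingFaceH4/zephyr-7b-beta')

X = np.array(["Hey there", "No thank you"])
y = ['1', '14']  # labels sharing the prefix '1'

clf.fit(X, y)
clf.predict(X)  # first call fills the logits cache
clf.predict(X)  # second call reuses the cached shared prefix and raises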
Exception:
Tokenizer info:
Not shown here is that tokenizer.add_eos_token is False in this case, which is the cause of this issue and apparently the default for LlamaTokenizer*.

The above sample can be fixed by adding the EOS token to the label ids manually:
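That snippet also did not survive; a minimal sketch of what the manual fix could look like (label_ids_ appears in the tests above, the tokenizer_ attribute is an assumption):

eos_token_id = clf.tokenizer_.eos_token_id  # assumed fitted-tokenizer attribute
clf.label_ids_ = [
    ids if ids and ids[-1] == eos_token_id else list(ids) + [eos_token_id]
    for ids in clf.label_ids_
]
clf.predict(X)  # works again once every label ends in the EOS token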
Update: I think the model must also not output an EOS token for the issue to occur.