(bump)
Hey! Thanks for opening this issue! It seems to be related to this line, where we define the sequence length tensor; most of our models that compute partial pooled logits use this. Can you try something like
```python
if input_ids is not None:
    sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(logits.device)
```
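For context, the sequence-length tensor is then used to pick one logit per sequence; a minimal sketch of that pooling step (illustrative, not the exact modeling code) looks like:

```python
import torch

# logits: [batch_size, seq_len, num_labels]; sequence_lengths: [batch_size]
batch_size = logits.shape[0]
# Gather the logit at the last non-padding position of each sequence.
# A value of -1 (no pad token found, or left padding with the fix above)
# simply selects the last position via negative indexing.
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
```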
I'll open a PR to fix it!
Thanks @ArthurZucker, the fix works great.
Seems the PR misses a few models: biogpt, bloom, falcon, mpt.
There was a follow-up PR: #25085; other models might have been forgotten!
System Info

transformers version: 4.30.0.dev0

Who can help?

text models: @ArthurZucker and @younesbelkada

Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
I found that `BloomForSequenceClassification` (and possibly other causal models) produces outputs that change with `max_length` when the tokenizer uses `padding_side = "left"`. It seems to be caused by this line: https://github.com/huggingface/transformers/blob/v4.30.1/src/transformers/models/bloom/modeling_bloom.py#L1080, which appears to assume right padding.

If this diagnosis is correct, imho it's quite unintuitive and error-prone, because: 1) Bloom's default `padding_side` is `left`, and 2) many tutorials (e.g. PEFT P-tuning for sequence classification) recommend setting `padding_side = "left"` for causal models. Could you provide some guidance? What's the correct way to use causal models for sequence classification?
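To make the suspected assumption concrete, here is a toy illustration (my own sketch, not the library code verbatim; it assumes the linked line derives the last-token index by counting non-pad tokens):

```python
import torch

pad = 3  # Bloom's pad token id
# A two-token sequence, left-padded to length 5.
input_ids = torch.tensor([[pad, pad, pad, 17, 42]])

# Counting non-pad tokens and subtracting 1 only works for right padding:
right_pad_assumption = torch.ne(input_ids, pad).sum(-1) - 1
print(right_pad_assumption)  # tensor([1]) -> a padding position, not the last real token

# Locating the first pad token instead also handles left padding
# (argmax over the all-padding prefix returns 0, so the index becomes -1, i.e. the last position):
first_pad_minus_one = torch.eq(input_ids, pad).long().argmax(-1) - 1
print(first_pad_minus_one)  # tensor([-1])
```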
Sample to reproduce:
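A minimal sketch of such a reproduction (my own illustration; the checkpoint name and padding lengths are placeholders):

```python
import torch
from transformers import AutoTokenizer, BloomForSequenceClassification

checkpoint = "bigscience/bloom-560m"  # any Bloom checkpoint works for the comparison
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
model = BloomForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()

text = "The quick brown fox jumps over the lazy dog."
with torch.no_grad():
    # Same sentence, padded to two different lengths.
    short = tokenizer(text, return_tensors="pt", padding="max_length", max_length=16)
    long = tokenizer(text, return_tensors="pt", padding="max_length", max_length=64)
    logits_short = model(**short).logits
    logits_long = model(**long).logits

print(logits_short)
print(logits_long)  # differs from logits_short on v4.30 with left padding
```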
Expected behavior
The model should produce the same outputs regardless of padding length.
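Concretely, with the tensors from the sketch above, a check along these lines should pass once the last-token index respects the padding side (the tolerance is a placeholder chosen loosely to allow for numerical noise):

```python
torch.testing.assert_close(logits_short, logits_long, rtol=1e-4, atol=1e-4)
```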