EleutherAI / elk

Keeping language models honest by directly eliciting knowledge encoded in their activations.

Encoder-only models do not expect a `labels` argument to be passed to `forward()` #262

Closed AugustasMacijauskas closed 1 year ago

AugustasMacijauskas commented 1 year ago

Reproduce

elicit microsoft/deberta-v2-xxlarge <dataset>

or

elicit allenai/unifiedqa-v2-t5-11b-1363200 <dataset> --use_encoder_states

Full stack trace:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /fsx/home-augustas/anaconda3/envs/elk/lib/python3.10/site-packages/datasets/ │
│ builder.py:1629 in _prepare_split_single                                     │
│                                                                              │
│   1626 │   │   │   )                                                         │
│   1627 │   │   │   try:                                                      │
│   1628 │   │   │   │   _time = time.time()                                   │
│ ❱ 1629 │   │   │   │   for key, record in generator:                         │
│   1630 │   │   │   │   │   if max_shard_size is not None and writer._num_byt │
│   1631 │   │   │   │   │   │   num_examples, num_bytes = writer.finalize()   │
│   1632 │   │   │   │   │   │   writer.close()                                │
│                                                                              │
│ /fsx/home-augustas/elk/elk/extraction/generator.py:86 in _generate_examples  │
│                                                                              │
│   83 │   def _generate_examples(self, **gen_kwargs):                         │
│   84 │   │   assert self.config.generator is not None, "generator must be sp │
│   85 │   │                                                                   │
│ ❱ 86 │   │   for idx, ex in enumerate(self.config.generator(**gen_kwargs)):  │
│   87 │   │   │   yield idx, ex                                               │
│   88                                                                         │
│                                                                              │
│ /fsx/home-augustas/elk/elk/extraction/extraction.py:339 in                   │
│ _extraction_worker                                                           │
│                                                                              │
│   336                                                                        │
│   337 # Dataset.from_generator wraps all the arguments in lists, so we unpac │
│   338 def _extraction_worker(**kwargs):                                      │
│ ❱ 339 │   yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()}) │
│   340                                                                        │
│   341                                                                        │
│   342 def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]:     │
│                                                                              │
│ /fsx/home-augustas/anaconda3/envs/elk/lib/python3.10/site-packages/torch/uti │
│ ls/_contextlib.py:35 in generator_context                                    │
│                                                                              │
│    32 │   │   try:                                                           │
│    33 │   │   │   # Issuing `None` to a generator fires it up                │
│    34 │   │   │   with ctx_factory():                                        │
│ ❱  35 │   │   │   │   response = gen.send(None)                              │
│    36 │   │   │                                                              │
│    37 │   │   │   while True:                                                │
│    38 │   │   │   │   try:                                                   │
│                                                                              │
│ /fsx/home-augustas/elk/elk/extraction/extraction.py:287 in extract_hiddens   │
│                                                                              │
│   284 │   │   │   │   │   variant_questions.append(text)                     │
│   285 │   │   │   │                                                          │
│   286 │   │   │   │   inputs = dict(input_ids=ids.long(), labels=labels)     │
│ ❱ 287 │   │   │   │   outputs = model(**inputs, output_hidden_states=True)   │
│   288 │   │   │   │                                                          │
│   289 │   │   │   │   # Compute the log probability of the answer tokens if  │
│   290 │   │   │   │   if has_lm_preds:                                       │
│                                                                              │
│ /fsx/home-augustas/anaconda3/envs/elk/lib/python3.10/site-packages/torch/nn/ │
│ modules/module.py:1501 in _call_impl                                         │
│                                                                              │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                      │
│   1502 │   │   # Do not call functions when jit is used                      │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1504 │   │   backward_pre_hooks = []                                       │
│                                                                              │
│ /fsx/home-augustas/anaconda3/envs/elk/lib/python3.10/site-packages/accelerat │
│ e/hooks.py:165 in new_forward                                                │
│                                                                              │
│   162 │   │   │   with torch.no_grad():                                      │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                  │
│   164 │   │   else:                                                          │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                      │
│   166 │   │   return module._hf_hook.post_forward(module, output)            │
│   167 │                                                                      │
│   168 │   module.forward = new_forward                                       │
╰──────────────────────────────────────────────────────────────────────────────╯
TypeError: T5Stack.forward() got an unexpected keyword argument 'labels'
artkpv commented 1 year ago

The forward method doesn't accept a `labels` argument. See https://huggingface.co/docs/transformers/model_doc/deberta-v2#transformers.DebertaV2Model.forward
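
A minimal sketch (not from the repo) of the constraint: encoder-only models such as `DebertaV2Model` define `forward()` without a `labels` parameter, so passing one raises the same `TypeError` shown in the traceback above.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Any encoder-only checkpoint behaves the same way; deberta-v2-xxlarge is just
# the one from the repro command above.
model_name = "microsoft/deberta-v2-xxlarge"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

inputs = tokenizer("Is the sky blue? Yes.", return_tensors="pt")

with torch.no_grad():
    # Works: a plain encoder forward pass, hidden states included.
    outputs = model(**inputs, output_hidden_states=True)

    # Raises:
    #   TypeError: DebertaV2Model.forward() got an unexpected keyword argument 'labels'
    # outputs = model(**inputs, labels=inputs["input_ids"], output_hidden_states=True)
```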

AugustasMacijauskas commented 1 year ago

@artkpv Yes, this is precisely what I think causes the problem: the labels are passed in when they should not be, see: https://github.com/EleutherAI/elk/blob/ec2b8a0ab27214a325954a7acc68dc01085642c4/elk/extraction/extraction.py#L286
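
One possible fix, sketched below, is to only attach `labels` when the loaded model actually consumes them (a seq2seq LM run end-to-end), not for encoder-only models or a bare encoder stack. The helper name `build_forward_inputs` and the `use_encoder_states` argument are hypothetical here; they mirror the CLI flag and the traceback, and the actual names in `extraction.py` may differ.

```python
import torch
from transformers import PreTrainedModel


def build_forward_inputs(
    model: PreTrainedModel,
    ids: torch.Tensor,
    labels: torch.Tensor | None,
    use_encoder_states: bool,
) -> dict:
    """Only include `labels` when the model's forward() accepts them.

    Seq2seq LMs (e.g. T5ForConditionalGeneration) use `labels` to build decoder
    inputs and an LM loss; encoder-only models (DebertaV2Model) and a bare T5
    encoder stack (the --use_encoder_states path) do not, which is what
    triggers the TypeError above.
    """
    inputs = dict(input_ids=ids.long())
    if getattr(model.config, "is_encoder_decoder", False) and not use_encoder_states:
        inputs["labels"] = labels
    return inputs


# Hypothetical usage at the call site from the traceback (extraction.py:286-287):
# inputs = build_forward_inputs(model, ids, labels, cfg.use_encoder_states)
# outputs = model(**inputs, output_hidden_states=True)
```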