EleutherAI / elk

Keeping language models honest by directly eliciting knowledge encoded in their activations.
MIT License

Cannot run README command `elk elicit microsoft/deberta-v2-xxlarge-mnli imdb` #256

Status: Open. rusheb opened this issue 1 year ago.

rusheb commented 1 year ago

Summary

Hi there,

I was attempting to set up the application based on the Quick Start guide in the README. When trying to run the elk elicit microsoft/deberta-v2-xxlarge-mnli imdb command, the program terminates with a ValueError.

Can you please help me understand what's going wrong and how I can fix it?

Thanks!

Steps to reproduce

Environment

Commands run

python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev]"
python --version  # outputs Python 3.11.2
# Note: I get a similar error when running other commands from the README
elk elicit microsoft/deberta-v2-xxlarge-mnli imdb

I also tried a couple of suggestions from @KayKozaronek and got the same error:

Error output

ValueError: mutable default <class 'elk.training.train.Elicit'> for field run_template is not allowed: use default_factory

(Full stack trace at the bottom of this post)

Mitigation

GPT suggests fixing this by replacing the default value of the following field in sweep.py with a default_factory, but without more context on the project I can't tell whether this is the right solution.

    run_template: Elicit = Elicit(
        data=Extract(
            model="<placeholder>",
            datasets=("<placeholder>",),
        )
    )
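A minimal sketch of what the suggested default_factory change could look like. Extract and Elicit below are hypothetical, simplified stand-ins for elk's config classes (the real ones have more fields), and default_run_template is an illustrative helper name, not something from the elk codebase. The idea is that dataclasses.field(default_factory=...) calls the factory once per Sweep instance instead of sharing one Elicit object as a class-level default.

```python
from dataclasses import dataclass, field


# Hypothetical, simplified stand-ins for elk's Extract and Elicit configs;
# the real classes in elk have more fields.
@dataclass
class Extract:
    model: str
    datasets: tuple[str, ...]


@dataclass
class Elicit:
    data: Extract


def default_run_template() -> Elicit:
    # Hypothetical helper: called once per Sweep instance,
    # so each sweep gets its own template object.
    return Elicit(
        data=Extract(
            model="<placeholder>",
            datasets=("<placeholder>",),
        )
    )


@dataclass
class Sweep:
    models: list[str]
    """List of Huggingface model strings to sweep over."""

    run_template: Elicit = field(default_factory=default_run_template)


a = Sweep(models=["gpt2"])
b = Sweep(models=["gpt2"])
print(a.run_template == b.run_template)  # True: same placeholder values
print(a.run_template is b.run_template)  # False: distinct objects
```

A side effect of this design is that mutating one sweep's run_template no longer leaks into other Sweep instances, which is the hazard the dataclass check is guarding against.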

Full stack trace

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /Users/rusheb/code/elk/venv/bin/elk:5 in <module>                                                │
│                                                                                                  │
│   2 # -*- coding: utf-8 -*-                                                                      │
│   3 import re                                                                                    │
│   4 import sys                                                                                   │
│ ❱ 5 from elk.__main__ import run                                                                 │
│   6 if __name__ == '__main__':                                                                   │
│   7 │   sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])                         │
│   8 │   sys.exit(run())                                                                          │
│                                                                                                  │
│ /Users/rusheb/code/elk/elk/__main__.py:9 in <module>                                             │
│                                                                                                  │
│    6                                                                                             │
│    7 from elk.evaluation.evaluate import Eval                                                    │
│    8 from elk.plotting.command import Plot                                                       │
│ ❱  9 from elk.training.sweep import Sweep                                                        │
│   10 from elk.training.train import Elicit                                                       │
│   11                                                                                             │
│   12                                                                                             │
│                                                                                                  │
│ /Users/rusheb/code/elk/elk/training/sweep.py:28 in <module>                                      │
│                                                                                                  │
│    25 │   │   get_dataset_config_info(dataset_name)                                              │
│    26                                                                                            │
│    27                                                                                            │
│ ❱  28 @dataclass                                                                                 │
│    29 class Sweep:                                                                               │
│    30 │   models: list[str]                                                                      │
│    31 │   """List of Huggingface model strings to sweep over."""                                 │
│                                                                                                  │
│ /opt/homebrew/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3. │
│ 11/dataclasses.py:1220 in dataclass                                                              │
│                                                                                                  │
│   1217 │   │   return wrap                                                                       │
│   1218 │                                                                                         │
│   1219 │   # We're called as @dataclass without parens.                                          │
│ ❱ 1220 │   return wrap(cls)                                                                      │
│   1221                                                                                           │
│   1222                                                                                           │
│   1223 def fields(class_or_instance):                                                            │
│                                                                                                  │
│ /opt/homebrew/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3. │
│ 11/dataclasses.py:1210 in wrap                                                                   │
│                                                                                                  │
│   1207 │   """                                                                                   │
│   1208 │                                                                                         │
│   1209 │   def wrap(cls):                                                                        │
│ ❱ 1210 │   │   return _process_class(cls, init, repr, eq, order, unsafe_hash,                    │
│   1211 │   │   │   │   │   │   │     frozen, match_args, kw_only, slots,                         │
│   1212 │   │   │   │   │   │   │     weakref_slot)                                               │
│   1213                                                                                           │
│                                                                                                  │
│ /opt/homebrew/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3. │
│ 11/dataclasses.py:958 in _process_class                                                          │
│                                                                                                  │
│    955 │   │   │   kw_only = True                                                                │
│    956 │   │   else:                                                                             │
│    957 │   │   │   # Otherwise it's a field of some type.                                        │
│ ❱  958 │   │   │   cls_fields.append(_get_field(cls, name, type, kw_only))                       │
│    959 │                                                                                         │
│    960 │   for f in cls_fields:                                                                  │
│    961 │   │   fields[f.name] = f                                                                │
│                                                                                                  │
│ /opt/homebrew/Cellar/python@3.11/3.11.2_1/Frameworks/Python.framework/Versions/3.11/lib/python3. │
│ 11/dataclasses.py:815 in _get_field                                                              │
│                                                                                                  │
│    812 │   # indicator for mutability.  Read the __hash__ attribute from the class,              │
│    813 │   # not the instance.                                                                   │
│    814 │   if f._field_type is _FIELD and f.default.__class__.__hash__ is None:                  │
│ ❱  815 │   │   raise ValueError(f'mutable default {type(f.default)} for field '                  │
│    816 │   │   │   │   │   │    f'{f.name} is not allowed: use default_factory')                 │
│    817 │                                                                                         │
│    818 │   return f                                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯

norabelrose commented 1 year ago

I'm not really sure what's wrong here. It may be a 3.11-only issue. I haven't had a chance to test it myself.
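One way to check the 3.11 hypothesis: the last traceback frame (dataclasses.py:814) rejects any default whose class has __hash__ set to None. Python 3.11 changed this check. Earlier versions only rejected list, dict, and set defaults, while 3.11+ treats any unhashable default as mutable. A dataclass with the default eq=True and frozen=False gets __hash__ = None, so an instance of it as a default trips the check. Template below is a hypothetical stand-in for a non-frozen config class such as elk's Elicit.

```python
import sys
from dataclasses import dataclass
from typing import Optional


@dataclass
class Template:
    # Hypothetical stand-in for a non-frozen config dataclass like elk's Elicit.
    model: str = "<placeholder>"


# eq=True (the default) with frozen=False sets __hash__ to None,
# so instances of Template are unhashable.
print(Template.__hash__ is None)  # True


def define_with_instance_default() -> Optional[str]:
    """Return the ValueError message, or None if the class definition succeeds."""
    try:
        @dataclass
        class Sweep:
            run_template: Template = Template()
    except ValueError as exc:
        return str(exc)
    return None


error = define_with_instance_default()
if sys.version_info >= (3, 11):
    # 3.11+ rejects any unhashable default, matching the traceback above.
    print("rejected:", error)
else:
    # Earlier versions only reject list, dict and set defaults.
    print("accepted:", error is None)
```

Making the config dataclass frozen would also satisfy the check, because a frozen dataclass with eq=True gets a generated __hash__. Whether elk's configs can be frozen is not something this thread establishes, so default_factory is the less invasive option.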

dribnet commented 3 months ago

I also cannot run the command listed in the README on Linux, though my error is different:

$ elk elicit microsoft/deberta-v2-xxlarge-mnli imdb
Traceback (most recent call last):
  ...
  File "/mnt/md1/nets/elk/elk/extraction/extraction.py", line 336, in extract
    raise ValueError("Can only extract LM predictions from autoregressive models.")
ValueError: Can only extract LM predictions from autoregressive models.

Update: I found that replacing the model with gpt2 worked fine, after a minor tweak to fix an exception name issue.