EleutherAI / elk

Keeping language models honest by directly eliciting knowledge encoded in their activations.
MIT License

Allow prompt templates to say if they should be binarized #234

Closed norabelrose closed 1 year ago

norabelrose commented 1 year ago

Adds a binarize field to prompt template YAML files. This is sorta needed for sweeps in which some datasets need to be binarized and others do not.
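To make the idea concrete, here is a minimal sketch of how a per-template `binarize` flag could drive choice selection during a sweep. This is not elk's actual implementation: the `TEMPLATES` dict stands in for the parsed YAML metadata, and `select_choices` is a hypothetical helper. Only the field name `binarize` comes from this PR.

```python
import random

# Hypothetical stand-in for metadata parsed from the prompt template YAML files.
# Only the `binarize` field name comes from the PR; the layout is assumed.
TEMPLATES = {
    "ag_news": {"binarize": True},   # 4-way classification, reduced to 2 choices
    "imdb": {"binarize": False},     # already binary, left untouched
}


def select_choices(dataset: str, label: int, num_classes: int, rng: random.Random) -> list[int]:
    """Return the label indices that would be shown as answer choices for one example."""
    if not TEMPLATES[dataset]["binarize"]:
        return list(range(num_classes))
    # Binarize: keep the true label plus one randomly chosen distractor.
    distractor = rng.choice([c for c in range(num_classes) if c != label])
    return sorted([label, distractor])
```

With a flag like this, one sweep could mix datasets: `select_choices("ag_news", 2, 4, rng)` would yield two choices including label 2, while `select_choices("imdb", 1, 2, rng)` would keep both original classes.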

lauritowal commented 1 year ago

Getting errors right now:

(.venv) laurito@ipe-monster:~/elk$ elk sweep --models gpt2 --datasets ag_news --max_examples 10 10 --num_gpus 1
Starting sweep over 1 models and 1 datasets (1 runs)
Models: ['gpt2']
Datasets: ['ag_news']
Saving sweep results to /home/wombat_share/laurito/elk_reporters/sweeps/stoic-merkle
===== gpt2 (1 of 1) =====
Using 1 of 7 GPUs: [1]
ag_news using 'train' for training and 'test' for validation
Found cached dataset generator (/home/wombat_share/laurito/.hugginface/datasets/generator/default-b9aa46e9e0062f0d/0.0.0)
Found cached dataset generator (/home/wombat_share/laurito/.hugginface/datasets/generator/default-3941ce3946a68d45/0.0.0)
Output directory at /home/wombat_share/laurito/elk_reporters/sweeps/stoic-merkle/gpt2/ag_news
  0%|          | 0/13 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/laurito/elk/.venv/bin/elk", line 8, in <module>
    sys.exit(run())
  File "/home/laurito/elk/elk/__main__.py", line 27, in run
    run.execute()
  File "/home/laurito/elk/elk/__main__.py", line 19, in execute
    return self.command.execute()
  File "/home/laurito/elk/elk/training/sweep.py", line 81, in execute
    run.execute()
  File "/home/laurito/elk/elk/run.py", line 98, in execute
    self.apply_to_layers(func=func, num_devices=num_devices)
  File "/home/laurito/elk/elk/run.py", line 182, in apply_to_layers
    for df_dict in tqdm(mapper(func, layers), total=len(layers)):
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/laurito/elk/elk/training/train.py", line 57, in apply_to_layer
    train_dict = self.prepare_data(device, layer, "train")
  File "/home/laurito/elk/elk/run.py", line 138, in prepare_data
    val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2779, in __getitem__
    return self._getitem(key)
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2764, in _getitem
    formatted_output = format_table(
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 624, in format_table
    return formatter(pa_table, query_type=query_type)
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 398, in __call__
    return self.format_column(pa_table)
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py", line 86, in format_column
    column = self.numpy_arrow_extractor().extract_column(pa_table)
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 161, in extract_column
    return self._arrow_array_to_numpy(pa_table[pa_table.column_names[0]])
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 171, in _arrow_array_to_numpy
    array: List = [
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 172, in <listcomp>
    row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)
  File "/home/laurito/elk/.venv/lib/python3.10/site-packages/datasets/features/features.py", line 726, in to_numpy
    numpy_arr = numpy_arr.reshape(len(self) - len(null_indices), *self.type.shape)
ValueError: cannot reshape array of size 230400 into shape (10,15,4,768)

Having a look at it too now
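A quick sanity check on the numbers in the `ValueError` may help narrow this down. The sketch below only does arithmetic on the two values from the traceback. Reading the axes as (examples, prompt variants, choices, hidden size) is an assumption, as is the guess that only the choices axis differs.

```python
import math

expected_shape = (10, 15, 4, 768)  # target shape reported in the ValueError
actual_size = 230400               # flat array size reported in the ValueError

expected_size = math.prod(expected_shape)
ratio = expected_size // actual_size  # how many times too large the target shape is

# Assumption: axis 2 is the number of answer choices and is the only axis that differs.
implied_choices = expected_shape[2] // ratio
implied_shape = (expected_shape[0], expected_shape[1], implied_choices, expected_shape[3])

print(expected_size, ratio, implied_shape)
```

The stored array is exactly half the size the dataset features expect, which would be consistent with hidden states having been extracted with 2 choices (binarized) while the feature schema still declares 4 (the number of `ag_news` classes). That is a hypothesis from the arithmetic, not a confirmed diagnosis.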