mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Decompression tokens #1223

Closed: milocress closed this issue 4 months ago

milocress commented 4 months ago

For our eval gauntlet, modify the prompt to place "thinking placeholder tokens" (a "." token or something similar) at the beginning and end of the prompt.

Measure: eval gauntlet performance
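A minimal sketch of the prompt modification, assuming the placeholder is a single token ID (for example the ID of ".") inserted after tokenization. The helper `add_thinking_tokens` and its parameters are hypothetical and are not taken from the `milo/decompression-tokens` branch.

```python
from typing import List, Sequence


def add_thinking_tokens(
    prompt_ids: Sequence[int],
    placeholder_id: int,
    prepend: int = 0,
    append: int = 0,
) -> List[int]:
    """Surround a tokenized prompt with `prepend` and `append` copies of a placeholder token.

    Tokens added before the prompt act as decompression tokens; tokens added
    after it sit between the prompt and the scored answer (planning tokens).
    """
    if prepend < 0 or append < 0:
        raise ValueError("prepend and append must be non-negative")
    return [placeholder_id] * prepend + list(prompt_ids) + [placeholder_id] * append


# Hypothetical token IDs; with a real tokenizer, look up the ID of "." yourself.
prompt_ids = [101, 102, 103]
dot_id = 13
print(add_thinking_tokens(prompt_ids, dot_id, prepend=2, append=3))
# [13, 13, 101, 102, 103, 13, 13, 13]
```

Working on token IDs rather than on the prompt string guarantees exactly `prepend` + `append` extra tokens, regardless of how a BPE tokenizer could merge a run of "." characters. With `max_seq_len: 1024`, 200 extra tokens also leave less room for long few-shot prompts.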

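What I expect could happen (pre-registering), with control = no thinking tokens, treatment A = thinking tokens before the prompt, and treatment B = thinking tokens after the prompt: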

  1. planning+decompression world (my prediction): treatment B > treatment A > control. The thinking tokens at the beginning of the sequence are used to "decompress" the model into the KV cache, which normally happens in the background over a large number of tokens and wouldn't happen at all if the prompt is too short. Since a bigger context means more FLOPs, adding thinking tokens to the beginning of the context will increase performance. However, adding thinking tokens to the middle of the context, where they benefit from seeing the prompt, lets the model inflate its KV metamodel with more relevant data. I'll call the before-prompt thinking tokens decompression tokens and the after-prompt thinking tokens planning tokens.
  2. decompression-only world (would be the most surprising and cool): treatment B == treatment A > control. The effect often attributed to planning tokens is actually due to decompression tokens: FLOPs per output token are the dominant factor in determining output-token quality.
  3. planning-only world (conventional wisdom): treatment B > treatment A == control
  4. indifferent world (also conventional wisdom): treatment B == treatment A == control
  5. distracted world (softmax sensitivity too high): control > treatments. Putting a bunch of unrelated tokens into the context window could make the model less accurate.
  6. disconnected world (wrong "." token): treatment A > control > treatment B. In particular, separating the question and the answer with many dividing tokens will cause the positional embeddings (or ALiBi, or whatever position scheme the model uses) to misbehave.
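Worlds 1 and 2 rest on the claim that extra context buys extra compute per output token. A rough back-of-the-envelope estimate, assuming standard multi-head attention, counting a multiply-add as 2 FLOPs, and using MPT-7B's public shape (32 layers, d_model = 4096, about 6.7B parameters): each extra context token adds about 4 · n_layers · d_model attention FLOPs per output token (query-key dot product plus weighted value sum, per layer), compared with roughly 2 · N FLOPs per token for the weight matmuls. The function names are illustrative only.

```python
def extra_attention_flops(extra_tokens: int, n_layers: int, d_model: int) -> int:
    """Approximate extra attention FLOPs per output token from `extra_tokens` more context.

    Per layer, each additional cached key/value position costs one query-key
    dot product (2 * d_model FLOPs summed over heads) and one weighted add of
    its value vector (another 2 * d_model FLOPs).
    """
    return 4 * n_layers * d_model * extra_tokens


def dense_flops_per_token(n_params: float) -> float:
    """Rule-of-thumb forward-pass FLOPs per token for the weight matmuls."""
    return 2 * n_params


# Assumed MPT-7B shape: 32 layers, d_model 4096, ~6.7B parameters.
N_LAYERS, D_MODEL, N_PARAMS = 32, 4096, 6.7e9

for extra in (100, 200):  # 100 = one side only, 200 = prepend + append
    added = extra_attention_flops(extra, N_LAYERS, D_MODEL)
    share = added / dense_flops_per_token(N_PARAMS)
    print(f"{extra} extra tokens: +{added:,} FLOPs/token (~{share:.2%} of dense compute)")
```

Under these assumptions, even 200 thinking tokens add less than 1% to MPT-7B's per-token compute (about 0.39% for 100 tokens and 0.78% for 200).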

If world 1 or 2 holds, the consequence would be that you could "grow" a small model (with no data at all) into a more capable one.
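To make the pre-registered outcomes checkable, here is a minimal sketch that maps three scores onto the six worlds. It assumes treatment A means thinking tokens before the prompt and treatment B means thinking tokens after it, and it treats two scores as "equal" when they differ by at most a tolerance `tol`; both `classify_world` and the default `tol=0.005` are hypothetical choices, not part of the original plan.

```python
def classify_world(control: float, a: float, b: float, tol: float = 0.005) -> str:
    """Map eval scores to the pre-registered worlds.

    control: score with no thinking tokens
    a: treatment A (thinking tokens before the prompt)
    b: treatment B (thinking tokens after the prompt)
    tol: hypothetical tolerance; differences <= tol count as "=="
    """

    def eq(x: float, y: float) -> bool:
        return abs(x - y) <= tol

    def gt(x: float, y: float) -> bool:
        return x - y > tol

    if eq(a, control) and eq(b, control) and eq(a, b):
        return "4: indifferent"
    if gt(b, a) and gt(a, control):
        return "1: planning + decompression"
    if eq(b, a) and gt(a, control) and gt(b, control):
        return "2: decompression only"
    if gt(b, a) and eq(a, control):
        return "3: planning only"
    if gt(a, control) and gt(control, b):
        return "6: disconnected"
    if gt(control, a) and gt(control, b):
        return "5: distracted"
    return "unclassified"


print(classify_world(0.50, 0.52, 0.55))  # hypothetical scores -> 1: planning + decompression
```

Orderings that match none of the six worlds come back as "unclassified" rather than being forced into the nearest one.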

Experiment YAML

```yaml
name: decompression-experiment
image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest
scheduling:
  priority: low
compute:
  gpus: 8
  cluster: auto
parameters:
  seed: 1
  models:
  # Control: unmodified prompts
  - model:
      name: hf_causal_lm
      pretrained: true
      init_device: mixed
      pretrained_model_name_or_path: mosaicml/mpt-7b
    tokenizer:
      name: mosaicml/mpt-7b
      kwargs:
        model_max_length: ${max_seq_len}
    model_name: mosaicml/mpt-7b
  # 100 thinking tokens before the prompt
  - model:
      name: hf_causal_lm
      pretrained: true
      init_device: mixed
      pretrained_model_name_or_path: mosaicml/mpt-7b
    tokenizer:
      name: mosaicml/mpt-7b
      kwargs:
        model_max_length: ${max_seq_len}
    model_name: mosaicml/mpt-7b-prepend-thought
    prepend_tokens: 100
  # 100 thinking tokens after the prompt
  - model:
      name: hf_causal_lm
      pretrained: true
      init_device: mixed
      pretrained_model_name_or_path: mosaicml/mpt-7b
    tokenizer:
      name: mosaicml/mpt-7b
      kwargs:
        model_max_length: ${max_seq_len}
    model_name: mosaicml/mpt-7b-append-thought
    append_tokens: 100
  # 100 thinking tokens on each side of the prompt
  - model:
      name: hf_causal_lm
      pretrained: true
      init_device: mixed
      pretrained_model_name_or_path: mosaicml/mpt-7b
    tokenizer:
      name: mosaicml/mpt-7b
      kwargs:
        model_max_length: ${max_seq_len}
    model_name: mosaicml/mpt-7b-append-and-prepend-thought
    append_tokens: 100
    prepend_tokens: 100
  loggers:
  run_name: mpt-7b-hf-eval-regression
  icl_tasks: eval/yamls/tasks.yaml
  precision: amp_fp16
  fsdp_config:
    mixed_precision: FULL
    forward_prefetch: true
    limit_all_gathers: true
    sharding_strategy: FULL_SHARD
  max_seq_len: 1024
  eval_gauntlet: eval/yamls/eval_gauntlet.yaml
  test_suite_tag: llm-foundry_regressions_cron_date_2024-05-09-08-00-34
  device_eval_batch_size: 4
  icl_subset_num_batches: 20  # evaluate only a subset of batches per ICL task
integrations:
# Fork branch that (presumably) implements prepend_tokens / append_tokens
- integration_type: git_repo
  git_repo: milocress/llm-foundry
  git_branch: milo/decompression-tokens
  pip_install: -e .[gpu-flash2]
env_variables:
  key: value  # placeholder entry
command: |-
  cd llm-foundry/scripts/
  composer eval/eval.py /mnt/config/parameters.yaml
```

Results

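| model_name | thinking tokens | core_average | lm_task_average | lite_average |
|---|---|---|---|---|
| mosaicml/mpt-7b-append-and-prepend-thought | 100 prepended + 100 appended | 0.326874 | 0.437117 | 0.472216 |
| mosaicml/mpt-7b-append-thought | 100 appended | 0.326556 | 0.438966 | 0.471418 |
| mosaicml/mpt-7b | none (control) | 0.324358 | 0.434596 | 0.465906 |
| mosaicml/mpt-7b-prepend-thought | 100 prepended | 0.32376 | 0.430823 | 0.465113 |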
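A small sketch that computes each run's difference from the control `mosaicml/mpt-7b`, using only the numbers in the results table. Every difference is below one percentage point. Because `icl_subset_num_batches: 20` limits each ICL task to a subset of batches, this arithmetic alone says nothing about whether the gaps exceed run-to-run noise. The names `RESULTS` and `deltas_vs_control` are illustrative.

```python
# Scores copied from the results table: (core_average, lm_task_average, lite_average).
RESULTS = {
    "mosaicml/mpt-7b-append-and-prepend-thought": (0.326874, 0.437117, 0.472216),
    "mosaicml/mpt-7b-append-thought": (0.326556, 0.438966, 0.471418),
    "mosaicml/mpt-7b": (0.324358, 0.434596, 0.465906),
    "mosaicml/mpt-7b-prepend-thought": (0.32376, 0.430823, 0.465113),
}
METRICS = ("core_average", "lm_task_average", "lite_average")
CONTROL = "mosaicml/mpt-7b"


def deltas_vs_control(results, control=CONTROL):
    """Return {model: {metric: score - control score}} for every non-control model."""
    base = results[control]
    return {
        name: {m: round(s - b, 6) for m, s, b in zip(METRICS, scores, base)}
        for name, scores in results.items()
        if name != control
    }


# Print each difference in percentage points (pp).
for name, d in deltas_vs_control(RESULTS).items():
    print(name, {m: f"{v * 100:+.2f} pp" for m, v in d.items()})
```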