mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0
3.83k stars · 502 forks

Allows interweaving of arbitrary kinds of 'attention' layers, like sliding window, reuse prev layer kv cache etc. #1299

Closed · ShashankMosaicML closed 3 days ago

ShashankMosaicML commented 1 week ago

This allows overriding the default block configs for certain layers. The `block_overrides` config contains two sub-configs: `order` and `overrides`. `order` specifies the sequence of the different kinds of layers (`default` refers to a layer that does not apply any overrides). For each kind of layer, the overrides are specified under the `overrides` sub-config. For example, to specify the model described in https://research.character.ai/optimizing-inference/, the following config would be needed:

```yaml
model:
    ...
    (usual model configs)
    ...
    block_overrides:
        order:
        - name: default
        - order:
          - name: sliding_window_layer
          - name: sliding_window_layer_reuse
          - name: sliding_window_layer
          - repeat: 2
            name: sliding_window_layer_reuse
          - name: reuse_kv_layer
          repeat: 2
        overrides:
            sliding_window_layer:
                attn_config:
                    sliding_window_size: 1024
            sliding_window_layer_reuse:
                attn_config:
                    sliding_window_size: 1024
                    reuse_kv_layer_idx: -1 # Relative index of the layer whose kv cache to reuse
            reuse_kv_layer:
                attn_config:
                    reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse
```
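To see how a nested `order` config with `repeat` counts flattens into a per-layer sequence, here is a minimal sketch. This is not llm-foundry's actual implementation; the function name `expand_order` is illustrative only, but the expansion matches the 13 rows (indices 0–12) of the summary table below.

```python
# Illustrative sketch: flatten a nested `order` config into a per-layer list.
# An entry is either a leaf ({"name": ...}) or a nested group ({"order": [...]}),
# and either kind may carry a "repeat" count (default 1).
def expand_order(order):
    layers = []
    for entry in order:
        for _ in range(entry.get("repeat", 1)):
            if "order" in entry:
                layers.extend(expand_order(entry["order"]))  # recurse into group
            else:
                layers.append(entry["name"])
    return layers

# The `order` from the example config above, as parsed YAML:
order = [
    {"name": "default"},
    {
        "order": [
            {"name": "sliding_window_layer"},
            {"name": "sliding_window_layer_reuse"},
            {"name": "sliding_window_layer"},
            {"repeat": 2, "name": "sliding_window_layer_reuse"},
            {"name": "reuse_kv_layer"},
        ],
        "repeat": 2,
    },
]

print(expand_order(order))  # 1 default layer + 2 x (6-layer group) = 13 layers
```

The outer group of six layers repeats twice, giving the 12 override layers plus the single leading `default` layer.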

It also prints the following log summarizing the network:

```
INFO: llmfoundry.models.mpt.modeling_mpt: The following is a summary of overrides per layer.
  idx  name                        overrides
-----  --------------------------  ----------------------------------------------------------
    0  default                     []
    1  sliding_window_layer        [{'sliding_window_size': 1024}]
    2  sliding_window_layer_reuse  [{'sliding_window_size': 1024}, {'reuse_kv_layer_idx': 1}]
    3  sliding_window_layer        [{'sliding_window_size': 1024}]
    4  sliding_window_layer_reuse  [{'sliding_window_size': 1024}, {'reuse_kv_layer_idx': 3}]
    5  sliding_window_layer_reuse  [{'sliding_window_size': 1024}, {'reuse_kv_layer_idx': 3}]
    6  reuse_kv_layer              [{'reuse_kv_layer_idx': 0}]
    7  sliding_window_layer        [{'sliding_window_size': 1024}]
    8  sliding_window_layer_reuse  [{'sliding_window_size': 1024}, {'reuse_kv_layer_idx': 7}]
    9  sliding_window_layer        [{'sliding_window_size': 1024}]
   10  sliding_window_layer_reuse  [{'sliding_window_size': 1024}, {'reuse_kv_layer_idx': 9}]
   11  sliding_window_layer_reuse  [{'sliding_window_size': 1024}, {'reuse_kv_layer_idx': 9}]
   12  reuse_kv_layer              [{'reuse_kv_layer_idx': 0}]
```

Note that the table above prints the absolute layer index for reuse_kv_layer_idx.
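The absolute indices in the table can be recovered from the relative ones in the config. Note from the table that if the target layer itself reuses a kv cache, the chain is followed back to the original layer: layer 12 with relative index -6 points at layer 6, which in turn reuses layer 0, so layer 12 resolves to 0. A minimal sketch of this resolution (names like `resolve_absolute` are illustrative, not llm-foundry's API):

```python
# Illustrative sketch: map each layer's relative reuse_kv_layer_idx to the
# absolute index of the layer whose kv cache is ultimately reused.
def resolve_absolute(relative_by_layer):
    """relative_by_layer: {layer_idx: relative_reuse_idx} for reusing layers."""
    absolute = {}
    for idx in sorted(relative_by_layer):
        target = idx + relative_by_layer[idx]
        # If the target itself reuses a kv cache, chase the chain to its source.
        absolute[idx] = absolute.get(target, target)
    return absolute

# Relative indices implied by the example config: sliding_window_layer_reuse
# uses -1, reuse_kv_layer uses -6.
relative = {2: -1, 4: -1, 5: -1, 6: -6, 8: -1, 10: -1, 11: -1, 12: -6}
print(resolve_absolute(relative))
```

With these inputs the result matches the table, e.g. layer 5 resolves to 3 (via layer 4) and layer 12 resolves to 0 (via layer 6).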

ShashankMosaicML commented 4 days ago

> mostly lgtm, are you done testing it?

Yes, we have finished testing this. Everything seems fine.