turingmotors / heron


Regarding the part that specifies the parameters to be trainable #24

Open Onely7 opened 1 year ago

Onely7 commented 1 year ago

I'd like to discuss how the trainable parameters are specified.
Which parameters should be trainable is specified in the projects/OOO/expOOO.yaml file:

model_config:
    keys_to_finetune:
        - visual_projection
    keys_to_freeze: []

You must list elements in either keys_to_finetune or keys_to_freeze (listing elements in both results in a ValueError).
I was initially unsure what values can be listed here, so I looked into it.
I believe the answer can be found by reading the set_trainable_params function in utils.py.

https://github.com/turingmotors/heron/blob/a52d8cfa00a6514011dd5d8c7d0b63afe7664c26/heron/models/utils.py#L159C1-L196

Also, in this set_trainable_params function in utils.py, a parameter is frozen when any string in the keys_to_freeze list appears as a substring of that parameter's name:

for name, p in model.named_parameters():

    ...

    # freeze the parameter if any key in keys_to_freeze is a substring of its name
    elif np.any([k in name for k in keys_to_freeze]):
        p.requires_grad = False
        untrainable_list.append(name)

    ...

In other words, if you specify a string that does not appear as a substring of any parameter name, it has no effect, exactly as if you had not specified it at all.
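To make this matching rule concrete, here is a minimal, self-contained illustration (my own snippet, not code from the repository); the parameter names are just examples of the form shown further below:

import numpy as np

# example parameter names of the kind returned by model.named_parameters()
names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.image_encoder.vision_model.post_layernorm.weight",
    "model.visual_projection.visual_projection.0.weight",
]

keys_to_freeze = ["image_encoder", "this_key_matches_nothing"]

for name in names:
    # same substring check as in the excerpt above
    frozen = np.any([k in name for k in keys_to_freeze])
    print(name, "->", "frozen" if frozen else "untouched")

# "this_key_matches_nothing" never appears in a name, so it changes nothing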

For example, let's check which modules the turing-motors/heron-chat-git-Llama-2-7b-v0 model has:

import torch
from heron.models.git_llm.git_llama import GitLlamaForCausalLM

device_id = 0

# load the pretrained model in fp16 and move it to the GPU
model = GitLlamaForCausalLM.from_pretrained(
    'turing-motors/heron-chat-git-Llama-2-7b-v0',
    torch_dtype=torch.float16,
)

model.eval()
model.to(f"cuda:{device_id}")

print(model)
GitLlamaForCausalLM(
  (model): GitLlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (image_encoder): CLIPVisionModel(
      (vision_model): CLIPVisionTransformer(
        (embeddings): CLIPVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
          (position_embedding): Embedding(577, 1024)
        )
        (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): CLIPEncoder(
          (layers): ModuleList(
            (0-23): 24 x CLIPEncoderLayer(
              (self_attn): CLIPAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
              )
              (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (mlp): CLIPMLP(
                (activation_fn): QuickGELUActivation()
                (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                (fc2): Linear(in_features=4096, out_features=1024, bias=True)
              )
              (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
          )
        )
        (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (visual_projection): GitProjection(
      (visual_projection): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
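Since the keys are matched against parameter names, the top-level submodule names visible in this printout (image_encoder, visual_projection, lm_head, and so on) are the natural candidates. As a quick way to enumerate them (an illustrative snippet of my own, not repository code):

# collect the first two components of each parameter name to see the
# coarse module names that can serve as keys
prefixes = sorted({".".join(name.split(".")[:2]) for name, _ in model.named_parameters()})
print(prefixes)
# roughly: ['lm_head.weight', 'model.embed_tokens', 'model.image_encoder',
#           'model.layers', 'model.norm', 'model.visual_projection']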

Additionally, to see the specific parameter names of the turing-motors/heron-chat-git-Llama-2-7b-v0 model:

for name, p in model.named_parameters():
    print(name)
model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight

... 

model.image_encoder.vision_model.post_layernorm.weight
model.image_encoder.vision_model.post_layernorm.bias
model.visual_projection.visual_projection.0.weight
model.visual_projection.visual_projection.0.bias
model.visual_projection.visual_projection.1.weight
model.visual_projection.visual_projection.1.bias
lm_head.weight

This lists every parameter name of the turing-motors/heron-chat-git-Llama-2-7b-v0 model.
When choosing which parameters to train (or to freeze), each string you put in keys_to_finetune (or keys_to_freeze) in projects/OOO/expOOO.yaml must appear as a substring of at least one of these parameter names.

For instance:

model_config:
    keys_to_finetune:
        - visual_projection
        - num_image_with_embedding
    keys_to_freeze: []

With this config, only the parameters of the turing-motors/heron-chat-git-Llama-2-7b-v0 model whose names contain visual_projection:

model.visual_projection.visual_projection.0.weight
model.visual_projection.visual_projection.0.bias
model.visual_projection.visual_projection.1.weight
model.visual_projection.visual_projection.1.bias

become trainable, while every parameter that matches none of the keys is frozen. (Since no parameter name in the turing-motors/heron-chat-git-Llama-2-7b-v0 model contains num_image_with_embedding, that key matches nothing and appears to have no effect.)
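To verify this, here is an illustrative re-implementation of the selection rule for that config (a sketch that assumes the keys_to_finetune branch mirrors the keys_to_freeze branch quoted above; it is not the actual utils.py code):

keys_to_finetune = ["visual_projection", "num_image_with_embedding"]

trainable, frozen = [], []
for name, p in model.named_parameters():
    # assumed to mirror the substring check quoted above for keys_to_freeze
    if any(k in name for k in keys_to_finetune):
        trainable.append(name)
    else:
        frozen.append(name)

print(trainable)
# expected, based on the listing above:
# ['model.visual_projection.visual_projection.0.weight',
#  'model.visual_projection.visual_projection.0.bias',
#  'model.visual_projection.visual_projection.1.weight',
#  'model.visual_projection.visual_projection.1.bias']
print(len(frozen))  # everything else stays frozen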
Hence, the sample config for training the llama-based VL model here includes:

model_config:
    keys_to_finetune:
        - visual_projection
        - num_image_with_embedding
    keys_to_freeze: []

Even though num_image_with_embedding is specified, I believe it is not necessary. Is my understanding correct?
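As a quick sanity check of that claim (again, just an illustrative one-liner, not repository code):

# does any parameter name contain "num_image_with_embedding"?
print(any("num_image_with_embedding" in name for name, _ in model.named_parameters()))
# expected to print False for this model, i.e. the key matches nothing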
If there are any errors or misconceptions in my explanation thus far, please let me know.