SparkJiao / llama-pipeline-parallel

A prototype repo for hybrid training combining pipeline parallelism and distributed data parallelism, with comments on core code snippets. Feel free to copy the code and open discussions about any problems you have encountered.

Question about "TiedLayerSpec" #8


xuanhua commented 1 hour ago

Hi @SparkJiao, I'm working on fine-tuning DeepSeek Coder models (e.g. 1B and 6.7B) with pipeline parallelism; as far as I know, they are based on the LLaMA architecture, and this repo has been a great help. But as a beginner, I don't quite understand the `TiedLayerSpec` provided by the DeepSpeed library, and I saw you provide two `get_model()` functions.

I just want to know which one I should use:

```python
def get_model(model):
    layers = [TiedLayerSpec("weight", EmbeddingPipeLayer, model=model, tied_weight_attr="weight"),
              *[LayerSpec(LlamaPipeLayer, model=model, layer_idx=idx) for idx in range(model.config.num_hidden_layers)],
              LayerSpec(FLNPipeLayer, model=model),
              TiedLayerSpec("weight", LMPipeLayer, model=model, tied_weight_attr="weight"),
              ]
    return layers


def get_layers_from_config(model_config, activation_checkpointing: bool = False):
    """
    `tie_word_embeddings` in LLaMA is set to `false`.
    """
    layers = [
        LayerSpec(EmbeddingPipe, model_config.vocab_size, model_config.hidden_size),
        # TiedLayerSpec("weight", EmbeddingPipe, model_config.vocab_size, model_config.hidden_size, tied_weight_attr="weight"),
        *[LayerSpec(ParallelTransformerLayerPipe, model_config, activation_checkpointing)
          for _ in range(model_config.num_hidden_layers)],
        LayerSpec(LayerNormPipe, model_config.hidden_size, model_config.rms_norm_eps),
        LayerSpec(LMLayerPipe, model_config.hidden_size, model_config.vocab_size, bias=False),
        # TiedLayerSpec("weight", LMLayerPipe, model_config.hidden_size, model_config.vocab_size, bias=False,
        #               tied_weight_attr="weight"),
        # LayerSpec(LossLayer),
    ]
    return layers
```

Looking forward to your reply.

SparkJiao commented 1 hour ago

I recommend using `get_layers_from_config`. I think there was a bug when using `get_model` that I failed to fix.
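
For reference, here is a minimal sketch of how the layer specs would typically be handed to DeepSpeed's `PipelineModule`; the model name, `num_stages` value, and `activation_checkpointing` flag are illustrative placeholders, not settings from this repo:

```python
from transformers import AutoConfig
from deepspeed.pipe import PipelineModule

# Illustrative wiring: load the HF config, build the layer specs,
# and let DeepSpeed partition them into pipeline stages.
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-coder-6.7b-base")
layers = get_layers_from_config(config, activation_checkpointing=True)
model = PipelineModule(
    layers=layers,
    num_stages=4,                   # pipeline parallel degree; example value
    partition_method="parameters",  # balance stages by parameter count
)
```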

`TiedLayerSpec` is used for initializing tied weights, which share gradient updates. In the Transformer architecture, the weights of the input embedding and the `lm_head` are sometimes tied. But I'm not sure whether DeepSeek's model uses this setting; you may check their config or paper.
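
If it helps, here is an illustrative sketch (not code from this repo) of how you could switch between the tied and untied variants based on the config's `tie_word_embeddings` flag. `EmbeddingPipe` and `LMLayerPipe` are the layer classes quoted above, and the helper name is hypothetical:

```python
from deepspeed.pipe import LayerSpec, TiedLayerSpec

def build_embed_and_head(model_config):
    """Hypothetical helper: pick layer specs based on weight tying."""
    if getattr(model_config, "tie_word_embeddings", False):
        # Tied case: both specs register under the same key ("weight"),
        # so DeepSpeed shares the parameter and synchronizes its gradients
        # across the stages that hold the embedding and the lm_head.
        embed = TiedLayerSpec("weight", EmbeddingPipe,
                              model_config.vocab_size, model_config.hidden_size,
                              tied_weight_attr="weight")
        head = TiedLayerSpec("weight", LMLayerPipe,
                             model_config.hidden_size, model_config.vocab_size,
                             bias=False, tied_weight_attr="weight")
    else:
        # Untied case, as in LLaMA (`tie_word_embeddings` is false).
        embed = LayerSpec(EmbeddingPipe,
                          model_config.vocab_size, model_config.hidden_size)
        head = LayerSpec(LMLayerPipe,
                         model_config.hidden_size, model_config.vocab_size,
                         bias=False)
    return embed, head
```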