OpenMOSS / CoLLiE

Collaborative Training of Large Language Models in an Efficient Way
https://openlmlab-collie.readthedocs.io
Apache License 2.0

The interpretation of the transposition operation when splitting weights across the tensor parallel group #154

Closed SparkJiao closed 6 months ago

SparkJiao commented 6 months ago

Hi, thanks very much for your contribution!

I have a question about the following code snippet:

https://github.com/OpenLMLab/collie/blob/main/collie/models/llama/model.py#L586-L612

    # reshape q_proj.weight and k_proj.weight
    if key.endswith("q_proj.weight"):
        part_state_dict[key] = (
            rearrange(
                part_state_dict[key],
                "(h two t) d -> h two t d",
                h=config.num_attention_heads,
                two=2,
            )
            .transpose(1, 2)
            .reshape(config.hidden_size, config.hidden_size)
        )
    elif key.endswith("k_proj.weight"):
        part_state_dict[key] = (
            rearrange(
                part_state_dict[key],
                "(h two t) d -> h two t d",
                h=num_key_value_heads,
                two=2,
            )
            .transpose(1, 2)
            .reshape(
                num_key_value_heads * head_dim,
                config.hidden_size,
            )
        )

Why is the transposition operation required here, and why does v_proj not need a similar operation?

Thanks very much in advance for your reply!

KaiLv69 commented 6 months ago

Hi. The forward pass of llama implemented in collie is consistent with the reference implementation at https://github.com/facebookresearch/llama, but not with the implementation and weight-storage format used in the Transformers repository. To load and save llama weights in the Transformers format, collie has to reshape the q and k matrices when loading and saving model weights. Only q and k are affected because the two implementations apply rotary position embeddings to the query and key projections differently (interleaved dimension pairs in the original vs. split halves in Transformers); v_proj is never rotated, so its layout is identical in both formats.
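
For a concrete picture, here is a minimal sketch (not collie code: the sizes are toy values, and `hf_permute` paraphrases the permutation applied in Transformers' `convert_llama_weights_to_hf.py`) checking that the `rearrange(...).transpose(1, 2).reshape(...)` pattern from the snippet is exactly the inverse of that permutation:

    import torch
    from einops import rearrange

    num_heads, head_dim = 4, 8              # toy sizes, not a real config
    hidden_size = num_heads * head_dim

    def hf_permute(w, n_heads=num_heads, dim1=hidden_size, dim2=hidden_size):
        # Permutation used when converting original llama checkpoints to the
        # Transformers layout: each head's rows are regrouped from interleaved
        # (even, odd) rotary pairs into two contiguous halves.
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    def collie_unpermute(w, n_heads=num_heads):
        # The operation from the snippet above: Transformers layout -> original.
        return (
            rearrange(w, "(h two t) d -> h two t d", h=n_heads, two=2)
            .transpose(1, 2)
            .reshape(hidden_size, hidden_size)
        )

    w = torch.randn(hidden_size, hidden_size)    # stands in for q_proj.weight
    assert torch.equal(collie_unpermute(hf_permute(w)), w)   # exact round trip

The same round trip holds for k_proj with `num_key_value_heads` heads; v_proj skips it entirely because no rotary permutation is ever applied to it.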

SparkJiao commented 6 months ago

Got it. Thanks very much!