SxJyJay opened this issue 3 months ago
I think there may be a bug in https://github.com/Alpha-VLLM/Lumina-mGPT/blob/c8e180aa20f0a5977bf168424f30aa2be58fad94/lumina_mgpt/model/modeling_xllmx_chameleon.py#L56. The full Chameleon model consists of:
```
ChameleonModel(
  (embed_tokens): Embedding(65536, 4096)
  (layers): ModuleList(
    (0-31): 32 x ChameleonDecoderLayer(
      (self_attn): ChameleonSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_norm): ChameleonLayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (k_norm): ChameleonLayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (rotary_emb): ChameleonRotaryEmbedding()
      )
      (mlp): ChameleonMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): ChameleonRMSNorm((4096,), eps=1e-05)
      (post_attention_layernorm): ChameleonRMSNorm((4096,), eps=1e-05)
      (dropout): Dropout(p=0.05, inplace=False)
    )
  )
  (norm): ChameleonRMSNorm((4096,), eps=1e-05)
)
```
However, `modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens]` ignores the `self.model.norm` part. This causes an error when saving the checkpoint. After I modify this line to `modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens, self.model.norm]`, the training can proceed.
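In other words, the patched method looks roughly like this (just a sketch; I assume everything else in the method stays as in the repository):

```python
# sketch of the patched method inside ChameleonXLLMXForConditionalGeneration;
# the rest of the method is assumed unchanged
def get_fsdp_wrap_module_list(self):
    # also include the final RMSNorm (self.model.norm) in the wrap list
    modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens, self.model.norm]
    return modules
```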
That's weird. The `get_fsdp_wrap_module_list` method is used for the `auto_wrap_policy` argument in the FSDP call. Note that FSDP wrapping is a recursive process, which means that not only the outermost model but also some of the inner submodules are wrapped into FSDP modules. Operations like parameter sharding, gathering, and flattening are then conducted at the FSDP-module level.
Importantly, the `auto_wrap_policy` argument defines "which submodules should be independently wrapped into new FSDP modules", not "which modules should be considered part of the model". So the fact that `self.model.norm` is absent from the list merely means it won't form an independently wrapped FSDP module; it will still be included in the outermost FSDP module.
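For illustration, a generic sketch (not our exact code) of how such a module list can be turned into an `auto_wrap_policy`; only the listed submodules become their own FSDP units, and everything else, including the final norm, ends up in the root FSDP module:

```python
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

# `model` stands for the ChameleonXLLMXForConditionalGeneration instance (assumption)
wrap_modules = set(model.get_fsdp_wrap_module_list())

# wrap a submodule into its own FSDP unit only if it appears in the list;
# every remaining parameter (e.g. model.model.norm) is flattened into the
# outermost (root) FSDP module rather than being dropped
auto_wrap_policy = functools.partial(
    lambda_auto_wrap_policy,
    lambda_fn=lambda module: module in wrap_modules,
)

fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
```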
Therefore, in our experience, the problem you mentioned might not be the real cause of the error you met. Have you made any other modifications to the code? And what's your PyTorch version?
Thanks for your response! I use 1 GPU to debug the code. The only modification I made is that I defined a `get_trainable_params` method in the `ChameleonXLLMXForConditionalGeneration` class to make only part of the parameters trainable, in order to save memory. After making the modification mentioned earlier (adding `self.model.norm` to the wrap list), the code works fine on both 1 GPU and 8 GPUs. I wonder whether that modification will affect the model's performance?
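Roughly, the method I added looks like the sketch below; the name filter shown here is only a placeholder, not my exact selection rule:

```python
def get_trainable_params(self):
    # placeholder rule: keep only the attention projections trainable and
    # freeze everything else to save memory (not the exact selection used)
    trainable = {}
    for name, param in self.named_parameters():
        if "self_attn" in name:  # hypothetical filter
            param.requires_grad = True
            trainable[name] = param
        else:
            param.requires_grad = False
    return trainable
```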
BTW, my pytorch version is 2.3.0.
Best regards.
During training, I found that the training procedure crashes when running https://github.com/Alpha-VLLM/Lumina-mGPT/blob/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/util/ckpt.py#L91
The error is: AssertionError: FSDP assumes model.norm.weight is in the state_dict but the state_dict only has odict_keys
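For context, a generic sketch of how a full state dict is usually gathered from an FSDP model before saving (standard PyTorch FSDP usage; not the exact logic in xllmx/util/ckpt.py):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

# `fsdp_model` stands for the FSDP-wrapped training model (assumption)
save_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_cfg):
    # gather the full, unsharded state dict (materialized on rank 0 only)
    state_dict = fsdp_model.state_dict()

if dist.get_rank() == 0:
    torch.save(state_dict, "consolidated.model.pth")
```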