THUDM / VisualGLM-6B

Chinese and English multimodal conversational language model

QLoRA: error when merging LoRA weights #350

Open zousss opened 5 months ago

zousss commented 5 months ago

RuntimeError                              Traceback (most recent call last)
Cell In[9], line 4
      1 from finetune_visualglm import FineTuneVisualGLMModel
      2 import argparse
----> 4 model, args = FineTuneVisualGLMModel.from_pretrained('/kaggle/working/checkpoints/finetune-visualglm-6b-04-09-09-10',
      5     args=argparse.Namespace(
      6         fp16=True,
      7         skip_init=True,
      8         use_gpu_initialization=True,
      9         device='cuda',
     10     ))
     11 model.get_mixin('lora').merge_lora()
     12 args.layer_range = []

File /opt/conda/lib/python3.10/site-packages/sat/model/base_model.py:207, in BaseModel.from_pretrained(cls, name, args, home_path, url, prefix, build_only, overwrite_args, **kwargs)
    205 model = get_model(args, cls, **kwargs)
    206 if not build_only:
--> 207     load_checkpoint(model, args, load_path=model_path, prefix=prefix)
    208 return model, args

File /opt/conda/lib/python3.10/site-packages/sat/training/model_io.py:238, in load_checkpoint(model, args, load_path, prefix)
    235     module = model
    237 # only load module, other hyperparameters are just for recording.
--> 238 missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
    239 if len(unexpected_keys) > 0:
    240     print_rank0(
    241         f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2138, in Module.load_state_dict(self, state_dict, strict, assign)
   2131         out = hook(module, incompatible_keys)
   2132         assert out is None, (
   2133             "Hooks registered with register_load_state_dict_post_hook are not"
   2134             "expected to return new values, if incompatible_keys need to be modified,"
   2135             "it should be done inplace."
   2136         )
-> 2138 load(self, state_dict)
   2139 del load
   2141 if strict:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2124     child_prefix = prefix + name + '.'
   2125     child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126     load(child, child_state_dict, child_prefix)
   2128 # Note that the hook can modify missing_keys and unexpected_keys.
   2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2124     child_prefix = prefix + name + '.'
   2125     child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126     load(child, child_state_dict, child_prefix)
   2128 # Note that the hook can modify missing_keys and unexpected_keys.
   2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

[... skipping similar frames: Module.load_state_dict.<locals>.load at line 2126 (3 times)]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2124     child_prefix = prefix + name + '.'
   2125     child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126     load(child, child_state_dict, child_prefix)
   2128 # Note that the hook can modify missing_keys and unexpected_keys.
   2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2120, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2118 if assign:
   2119     local_metadata['assign_to_params_buffers'] = assign
-> 2120 module._load_from_state_dict(
   2121     local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
   2122 for name, child in module._modules.items():
   2123     if child is not None:

File /opt/conda/lib/python3.10/site-packages/sat/model/finetune/lora2.py:47, in HackLinearNF4._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
     45 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
     46     if prefix + 'weight' in state_dict:
---> 47         self.weight.data.copy_(state_dict[prefix + 'weight'])
     48         if self.weight.data.dtype == torch.uint8:
     49             copy_nested_list(state_dict[prefix + 'quant_state'], self.weight.quant_state)

RuntimeError: output with shape [25165824, 1] doesn't match the broadcast shape [25165824, 0]

How can I solve this problem?
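
In case it helps with debugging, here is a minimal diagnostic sketch (not a fix) that loads the saved state dict directly with plain torch.load and prints every quantized (uint8) weight and every tensor that has a zero-sized dimension, i.e. the candidates for the [25165824, 0] broadcast error above. The iteration folder and file name are assumptions based on SAT's usual checkpoint layout (<ckpt_dir>/<iteration>/mp_rank_00_model_states.pt); adjust the path to whatever is actually on disk.

import torch

# Assumed SAT checkpoint layout: <ckpt_dir>/<iteration>/mp_rank_00_model_states.pt
# (the "300" iteration folder is a placeholder -- use the real one on disk).
ckpt_file = ('/kaggle/working/checkpoints/finetune-visualglm-6b-04-09-09-10'
             '/300/mp_rank_00_model_states.pt')

# 'module' is the same key that sat.training.model_io.load_checkpoint reads.
sd = torch.load(ckpt_file, map_location='cpu')['module']

for name, value in sd.items():
    if not torch.is_tensor(value):
        continue  # quant_state entries are nested lists, skip them here
    if 0 in value.shape or value.dtype == torch.uint8:
        print(name, tuple(value.shape), value.dtype)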