File /opt/conda/lib/python3.10/site-packages/sat/model/base_model.py:207, in BaseModel.from_pretrained(cls, name, args, home_path, url, prefix, build_only, overwrite_args, kwargs)
205 model = get_model(args, cls, kwargs)
206 if not build_only:
--> 207 load_checkpoint(model, args, load_path=model_path, prefix=prefix)
208 return model, args
File /opt/conda/lib/python3.10/site-packages/sat/training/model_io.py:238, in load_checkpoint(model, args, load_path, prefix)
235 module = model
237 # only load module, other hyperparameters are just for recording.
--> 238 missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
239 if len(unexpected_keys) > 0:
240 print_rank0(
241 f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2138, in Module.load_state_dict(self, state_dict, strict, assign)
2131 out = hook(module, incompatible_keys)
2132 assert out is None, (
2133 "Hooks registered with register_load_state_dict_post_hook are not"
2134 "expected to return new values, if incompatible_keys need to be modified,"
2135 "it should be done inplace."
2136 )
-> 2138 load(self, state_dict)
2139 del load
2141 if strict:
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict..load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict..load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
[... skipping similar frames: Module.load_state_dict.<locals>.load at line 2126 (3 times)]
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict..load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2120, in Module.load_state_dict..load(module, local_state_dict, prefix)
2118 if assign:
2119 local_metadata['assign_to_params_buffers'] = assign
-> 2120 module._load_from_state_dict(
2121 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
2122 for name, child in module._modules.items():
2123 if child is not None:
File /opt/conda/lib/python3.10/site-packages/sat/model/finetune/lora2.py:47, in HackLinearNF4._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
45 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
     46 if prefix + 'weight' in state_dict:
---> 47 self.weight.data.copy_(state_dict[prefix+'weight'])
48 if self.weight.data.dtype == torch.uint8:
49 copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)
RuntimeError: output with shape [25165824, 1] doesn't match the broadcast shape [25165824, 0]
RuntimeError Traceback (most recent call last) Cell In[9], line 4 1 from finetune_visualglm import FineTuneVisualGLMModel 2 import argparse ----> 4 model, args = FineTuneVisualGLMModel.from_pretrained('/kaggle/working/checkpoints/finetune-visualglm-6b-04-09-09-10', 5 args=argparse.Namespace( 6 fp16=True, 7 skip_init=True, 8 use_gpu_initialization=True, 9 device='cuda', 10 )) 11 model.get_mixin('lora').merge_lora() 12 args.layer_range = []
File /opt/conda/lib/python3.10/site-packages/sat/model/base_model.py:207, in BaseModel.from_pretrained(cls, name, args, home_path, url, prefix, build_only, overwrite_args, kwargs) 205 model = get_model(args, cls, kwargs) 206 if not build_only: --> 207 load_checkpoint(model, args, load_path=model_path, prefix=prefix) 208 return model, args
File /opt/conda/lib/python3.10/site-packages/sat/training/model_io.py:238, in load_checkpoint(model, args, load_path, prefix) 235 module = model 237 # only load module, other hyperparameters are just for recording. --> 238 missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False) 239 if len(unexpected_keys) > 0: 240 print_rank0( 241 f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2138, in Module.load_state_dict(self, state_dict, strict, assign) 2131 out = hook(module, incompatible_keys) 2132 assert out is None, ( 2133 "Hooks registered with
register_load_state_dict_post_hook
are not" 2134 "expected to return new values, if incompatible_keys need to be modified," 2135 "it should be done inplace." 2136 ) -> 2138 load(self, state_dict) 2139 del load 2141 if strict:File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict..load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict..load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict..load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2120, in Module.load_state_dict..load(module, local_state_dict, prefix)
2118 if assign:
2119 local_metadata['assign_to_params_buffers'] = assign
-> 2120 module._load_from_state_dict(
2121 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
2122 for name, child in module._modules.items():
2123 if child is not None:
File /opt/conda/lib/python3.10/site-packages/sat/model/finetune/lora2.py:47, in HackLinearNF4._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) 45 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 46 if prefix + 'weight' in state_dict: ---> 47 self.weight.data.copy_(state_dict[prefix+'weight']) 48 if self.weight.data.dtype == torch.uint8: 49 copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)
RuntimeError: output with shape [25165824, 1] doesn't match the broadcast shape [25165824, 0]
How can I solve this problem?