Closed iwanglei1 closed 2 months ago
Git 居然能发中文了!!
你怎么解决这个问题的?
# Minimal reproduction of the state_dict loading error reported in this issue:
# construct the ChatTTS facade, then load its pretrained models with default arguments.
chat = ChatTTS.Chat()
chat.load_models()
RuntimeError Traceback (most recent call last) Cell In[2], line 2 1 chat = ChatTTS.Chat() ----> 2 chat.load_models()
File K:\TestCode\ChatTTS\ChatTTS\core.py:45, in Chat.load_models(self, source) 43 if source == 'huggingface': 44 download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"]) ---> 45 self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
File K:\TestCode\ChatTTS\ChatTTS\core.py:83, in Chat._load(self, vocos_config_path, vocos_ckpt_path, dvae_config_path, dvae_ckpt_path, gpt_config_path, gpt_ckpt_path, decoder_config_path, decoder_ckpt_path, tokenizer_path, device) 81 gpt = GPT_warpper(**cfg).to(device).eval() 82 assert gpt_ckpt_path, 'gpt_ckpt_path should not be None' ---> 83 gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu')) 84 self.pretrain_models['gpt'] = gpt 85 self.logger.log(logging.INFO, 'gpt loaded.')
File D:\Anaconda\envs\torch2\lib\site-packages\torch\nn\modules\module.py:2041, in Module.load_state_dict(self, state_dict, strict) 2036 error_msgs.insert( 2037 0, 'Missing key(s) in state_dict: {}. '.format( 2038 ', '.join('"{}"'.format(k) for k in missing_keys))) 2040 if len(error_msgs) > 0: -> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 2042 self.__class__.__name__, "\n\t".join(error_msgs))) 2043 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for GPT_warpper: Missing key(s) in state_dict: "head_text.weight_g", "head_text.weight_v", "head_code.0.weight_g", "head_code.0.weight_v", "head_code.1.weight_g", "head_code.1.weight_v", "head_code.2.weight_g", "head_code.2.weight_v", "head_code.3.weight_g", "head_code.3.weight_v". Unexpected key(s) in state_dict: "head_text.parametrizations.weight.original0", "head_text.parametrizations.weight.original1", "head_code.0.parametrizations.weight.original0", "head_code.0.parametrizations.weight.original1", "head_code.1.parametrizations.weight.original0", "head_code.1.parametrizations.weight.original1", "head_code.2.parametrizations.weight.original0", "head_code.2.parametrizations.weight.original1", "head_code.3.parametrizations.weight.original0", "head_code.3.parametrizations.weight.original1".
我也遇到了 急等
这个方法解决问题:
这个错误表明在加载 GPT_warpper 的状态字典时,出现了键值不匹配的问题。具体来说,缺少了一些预期的键,同时有一些意外的键。这通常是因为模型的定义与其权重文件之间存在不一致。
为了修复这个问题,你可以尝试以下几种方法:
确保模型和权重文件匹配 确保 GPT_warpper 的定义与用于训练的权重文件完全匹配。你可能需要检查权重文件的版本以及模型定义的版本。
更新模型定义 如果你的权重文件是最新的,但模型定义是旧的,请尝试更新模型定义,使其与权重文件匹配。
手动匹配键 你可以手动修改权重文件或模型定义,以确保它们的键匹配。以下是一个示例,说明如何在加载状态字典时忽略意外键和缺失键:
import torch

# Assumes `gpt` is a GPT_warpper instance and `gpt_ckpt_path` is the path to its
# checkpoint (both come from the surrounding ChatTTS loader).
gpt_state_dict = torch.load(gpt_ckpt_path, map_location='cpu')

# The checkpoint stores the weight-normalised heads under the
# torch.nn.utils.parametrizations.weight_norm naming
# ("<module>.parametrizations.weight.original0" / "...original1"), while this
# model expects the legacy torch.nn.utils.weight_norm naming
# ("<module>.weight_g" / "<module>.weight_v"). They are the same tensors
# (original0 is g, original1 is v), so rename them. Dropping them and loading
# with strict=False would leave head_text / head_code unloaded.
expected_keys = set(gpt.state_dict())  # built once, not once per checkpoint key
suffix_pairs = (
    ('.parametrizations.weight.original0', '.weight_g'),
    ('.parametrizations.weight.original1', '.weight_v'),
)
renamed_state_dict = {}
for key, value in gpt_state_dict.items():
    if key not in expected_keys:
        for new_suffix, legacy_suffix in suffix_pairs:
            if key.endswith(new_suffix):
                candidate = key[:-len(new_suffix)] + legacy_suffix
                if candidate in expected_keys:
                    key = candidate
                break
    renamed_state_dict[key] = value

# Strict load: any mismatch that remains is reported instead of being ignored.
gpt.load_state_dict(renamed_state_dict)
修改 D:\Program Files\Python311\Lib\site-packages\ChatTTS\core.py 文件中的 load_models 函数:
def remap_weight_norm_keys(state_dict, expected_keys):
    """Return a copy of *state_dict* with weight-norm keys renamed to match *expected_keys*.

    PyTorch has two weight-norm APIs that store the same two tensors under
    different names:

    * legacy ``torch.nn.utils.weight_norm``: ``<module>.weight_g`` / ``<module>.weight_v``
    * ``torch.nn.utils.parametrizations.weight_norm``:
      ``<module>.parametrizations.weight.original0`` / ``...original1``
      (``original0`` is the magnitude ``g``, ``original1`` the direction ``v``).

    A checkpoint saved with one API therefore fails a strict ``load_state_dict``
    into a model built with the other. Each key that the model does not already
    expect is renamed to its counterpart in the other convention, provided that
    counterpart is expected. Every other key is kept unchanged, so a strict
    load still reports it.

    Args:
        state_dict: Mapping of parameter name to value, e.g. from ``torch.load``.
        expected_keys: Iterable of the parameter names the target model uses,
            e.g. ``model.state_dict().keys()``.

    Returns:
        A new ``OrderedDict``. The ``_metadata`` attribute of *state_dict*
        (read by ``load_state_dict``) is carried over when present.
    """
    from collections import OrderedDict

    suffix_pairs = (
        ('.parametrizations.weight.original0', '.weight_g'),
        ('.parametrizations.weight.original1', '.weight_v'),
    )
    expected = set(expected_keys)
    remapped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        if key not in expected:
            for param_suffix, legacy_suffix in suffix_pairs:
                if key.endswith(param_suffix):
                    candidate = key[:-len(param_suffix)] + legacy_suffix
                elif key.endswith(legacy_suffix):
                    candidate = key[:-len(legacy_suffix)] + param_suffix
                else:
                    continue
                if candidate in expected:
                    new_key = candidate
                break
        remapped[new_key] = value
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        remapped._metadata = metadata
    return remapped


def load_models(
    self,
    vocos_config_path: str = None,
    vocos_ckpt_path: str = None,
    dvae_config_path: str = None,
    dvae_ckpt_path: str = None,
    gpt_config_path: str = None,
    gpt_ckpt_path: str = None,
    decoder_config_path: str = None,
    decoder_ckpt_path: str = None,
    tokenizer_path: str = None,
    device: str = None,
):
    """Load whichever pretrained ChatTTS components have paths supplied.

    A component (vocos, dvae, gpt, decoder) is built only when its config path
    is given, and the tokenizer only when ``tokenizer_path`` is given. Each
    loaded object is stored in ``self.pretrain_models`` under that name.
    ``self.check_model()`` is called at the end.

    Args:
        vocos_config_path, vocos_ckpt_path: Vocos hparams file and weights.
        dvae_config_path, dvae_ckpt_path: DVAE OmegaConf config and weights.
        gpt_config_path, gpt_ckpt_path: GPT_warpper OmegaConf config and weights.
        decoder_config_path, decoder_ckpt_path: decoder (a DVAE) config and weights.
        tokenizer_path: file holding a pickled tokenizer object.
        device: device to place the modules on; when falsy,
            ``select_device(4096)`` picks one.

    Raises:
        AssertionError: a config path is given without its checkpoint path.
        RuntimeError: from ``load_state_dict`` when a checkpoint still does not
            match its model. For the GPT, this is checked after weight-norm
            key renaming.
    """
    if not device:
        device = select_device(4096)
        self.logger.log(logging.INFO, f'use {device}')

    if vocos_config_path:
        vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
        assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
        # map_location='cpu' like the other checkpoints, so CUDA-saved weights
        # also load on CPU-only machines. load_state_dict copies onto `device`.
        vocos.load_state_dict(torch.load(vocos_ckpt_path, map_location='cpu'))
        self.pretrain_models['vocos'] = vocos
        self.logger.log(logging.INFO, 'vocos loaded.')

    if dvae_config_path:
        cfg = OmegaConf.load(dvae_config_path)
        dvae = DVAE(**cfg).to(device).eval()
        assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
        dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
        self.pretrain_models['dvae'] = dvae
        self.logger.log(logging.INFO, 'dvae loaded.')

    if gpt_config_path:
        cfg = OmegaConf.load(gpt_config_path)
        gpt = GPT_warpper(**cfg).to(device).eval()
        assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
        gpt_state_dict = torch.load(gpt_ckpt_path, map_location='cpu')
        # The checkpoint and the model may name the weight-norm parameters of
        # head_text / head_code differently (parametrizations.weight.original0/1
        # vs weight_g/weight_v). Rename rather than drop them: filtering them out
        # and loading with strict=False would leave those heads unloaded.
        gpt_state_dict = remap_weight_norm_keys(gpt_state_dict, gpt.state_dict().keys())
        gpt.load_state_dict(gpt_state_dict)
        self.pretrain_models['gpt'] = gpt
        self.logger.log(logging.INFO, 'gpt loaded.')

    if decoder_config_path:
        cfg = OmegaConf.load(decoder_config_path)
        decoder = DVAE(**cfg).to(device).eval()
        assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
        decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
        self.pretrain_models['decoder'] = decoder
        self.logger.log(logging.INFO, 'decoder loaded.')

    if tokenizer_path:
        # NOTE(security): torch.load unpickles an arbitrary object here; only
        # load tokenizer files from a trusted source.
        tokenizer = torch.load(tokenizer_path, map_location='cpu')
        tokenizer.padding_side = 'left'
        self.pretrain_models['tokenizer'] = tokenizer
        self.logger.log(logging.INFO, 'tokenizer loaded.')

    self.check_model()
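以上代码修改了 load_models 函数,在加载 gpt 模型时,过滤掉了状态字典中意外的键,并在加载状态字典时设置 strict=False,以忽略缺失的键。这样可以消除报错,但要注意:被过滤掉的其实是 head_text 和 head_code 的权重。检查点使用的是新版 torch.nn.utils.parametrizations.weight_norm 的命名(parametrizations.weight.original0 / original1),而模型使用旧版 torch.nn.utils.weight_norm 的命名(weight_g / weight_v),其中 original0 对应 weight_g,original1 对应 weight_v。直接忽略这些键会使这些层保留初始化时的值,不会加载预训练权重,生成结果很可能不正确。更可靠的做法是先把这些键重命名为模型期望的名字,再以 strict=True 加载。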
This issue was closed because it has been inactive for 15 days since being marked as stale.
有解决方法吗?
# Minimal reproduction of the state_dict loading error reported in this issue:
# construct the ChatTTS facade, then load its pretrained models with default arguments.
chat = ChatTTS.Chat()
chat.load_models()
RuntimeError Traceback (most recent call last) Cell In[2], line 2 1 chat = ChatTTS.Chat() ----> 2 chat.load_models()
File K:\TestCode\ChatTTS\ChatTTS\core.py:45, in Chat.load_models(self, source) 43 if source == 'huggingface': 44 download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"]) ---> 45 self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
File K:\TestCode\ChatTTS\ChatTTS\core.py:83, in Chat._load(self, vocos_config_path, vocos_ckpt_path, dvae_config_path, dvae_ckpt_path, gpt_config_path, gpt_ckpt_path, decoder_config_path, decoder_ckpt_path, tokenizer_path, device) 81 gpt = GPT_warpper(**cfg).to(device).eval() 82 assert gpt_ckpt_path, 'gpt_ckpt_path should not be None' ---> 83 gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu')) 84 self.pretrain_models['gpt'] = gpt 85 self.logger.log(logging.INFO, 'gpt loaded.')
File D:\Anaconda\envs\torch2\lib\site-packages\torch\nn\modules\module.py:2041, in Module.load_state_dict(self, state_dict, strict) 2036 error_msgs.insert( 2037 0, 'Missing key(s) in state_dict: {}. '.format( 2038 ', '.join('"{}"'.format(k) for k in missing_keys))) 2040 if len(error_msgs) > 0: -> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 2042 self.__class__.__name__, "\n\t".join(error_msgs))) 2043 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for GPT_warpper: Missing key(s) in state_dict: "head_text.weight_g", "head_text.weight_v", "head_code.0.weight_g", "head_code.0.weight_v", "head_code.1.weight_g", "head_code.1.weight_v", "head_code.2.weight_g", "head_code.2.weight_v", "head_code.3.weight_g", "head_code.3.weight_v". Unexpected key(s) in state_dict: "head_text.parametrizations.weight.original0", "head_text.parametrizations.weight.original1", "head_code.0.parametrizations.weight.original0", "head_code.0.parametrizations.weight.original1", "head_code.1.parametrizations.weight.original0", "head_code.1.parametrizations.weight.original1", "head_code.2.parametrizations.weight.original0", "head_code.2.parametrizations.weight.original1", "head_code.3.parametrizations.weight.original0", "head_code.3.parametrizations.weight.original1".