2noise / ChatTTS

A generative speech model for daily dialogue.
https://2noise.com
GNU Affero General Public License v3.0
31.21k stars · 3.39k forks

Something is wrong; it could be that I replaced `from torch.nn.utils.parametrizations import weight_norm` with `from torch.nn.utils import weight_norm` #19

Closed iwanglei1 closed 2 months ago

iwanglei1 commented 4 months ago

```python
chat = ChatTTS.Chat()
chat.load_models()
```

```
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 2
      1 chat = ChatTTS.Chat()
----> 2 chat.load_models()

File K:\TestCode\ChatTTS\ChatTTS\core.py:45, in Chat.load_models(self, source)
     43 if source == 'huggingface':
     44     download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
---> 45 self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})

File K:\TestCode\ChatTTS\ChatTTS\core.py:83, in Chat._load(self, vocos_config_path, vocos_ckpt_path, dvae_config_path, dvae_ckpt_path, gpt_config_path, gpt_ckpt_path, decoder_config_path, decoder_ckpt_path, tokenizer_path, device)
     81 gpt = GPT_warpper(**cfg).to(device).eval()
     82 assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
---> 83 gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
     84 self.pretrain_models['gpt'] = gpt
     85 self.logger.log(logging.INFO, 'gpt loaded.')

File D:\Anaconda\envs\torch2\lib\site-packages\torch\nn\modules\module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036 error_msgs.insert(
   2037     0, 'Missing key(s) in state_dict: {}. '.format(
   2038         ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042         self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for GPT_warpper:
    Missing key(s) in state_dict: "head_text.weight_g", "head_text.weight_v", "head_code.0.weight_g", "head_code.0.weight_v", "head_code.1.weight_g", "head_code.1.weight_v", "head_code.2.weight_g", "head_code.2.weight_v", "head_code.3.weight_g", "head_code.3.weight_v".
    Unexpected key(s) in state_dict: "head_text.parametrizations.weight.original0", "head_text.parametrizations.weight.original1", "head_code.0.parametrizations.weight.original0", "head_code.0.parametrizations.weight.original1", "head_code.1.parametrizations.weight.original0", "head_code.1.parametrizations.weight.original1", "head_code.2.parametrizations.weight.original0", "head_code.2.parametrizations.weight.original1", "head_code.3.parametrizations.weight.original0", "head_code.3.parametrizations.weight.original1".
```
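For reference, the two imports named in the title store weight-normalized parameters under different state_dict key names, which is exactly the missing/unexpected key pattern shown above. A minimal standalone sketch (not ChatTTS code; the parametrized variant requires a PyTorch version that provides `torch.nn.utils.parametrizations.weight_norm`):

```python
import torch.nn as nn
from torch.nn.utils import weight_norm as legacy_weight_norm
from torch.nn.utils.parametrizations import weight_norm as parametrized_weight_norm

# Legacy helper: registers <name>_g / <name>_v parameters.
legacy = legacy_weight_norm(nn.Linear(8, 8))
print(sorted(legacy.state_dict().keys()))
# ['bias', 'weight_g', 'weight_v']

# Parametrizations helper: stores the same data under
# parametrizations.weight.original0 / original1.
modern = parametrized_weight_norm(nn.Linear(8, 8))
print(sorted(modern.state_dict().keys()))
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']
```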

iwanglei1 commented 4 months ago

Git actually lets you post in Chinese now!!

XuWink commented 4 months ago

How did you solve this problem?

```python
chat = ChatTTS.Chat()
chat.load_models()
```

```
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 2
      1 chat = ChatTTS.Chat()
----> 2 chat.load_models()

File K:\TestCode\ChatTTS\ChatTTS\core.py:45, in Chat.load_models(self, source)
     43 if source == 'huggingface':
     44     download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
---> 45 self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})

File K:\TestCode\ChatTTS\ChatTTS\core.py:83, in Chat._load(self, vocos_config_path, vocos_ckpt_path, dvae_config_path, dvae_ckpt_path, gpt_config_path, gpt_ckpt_path, decoder_config_path, decoder_ckpt_path, tokenizer_path, device)
     81 gpt = GPT_warpper(**cfg).to(device).eval()
     82 assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
---> 83 gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
     84 self.pretrain_models['gpt'] = gpt
     85 self.logger.log(logging.INFO, 'gpt loaded.')

File D:\Anaconda\envs\torch2\lib\site-packages\torch\nn\modules\module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036 error_msgs.insert(
   2037     0, 'Missing key(s) in state_dict: {}. '.format(
   2038         ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042         self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for GPT_warpper:
    Missing key(s) in state_dict: "head_text.weight_g", "head_text.weight_v", "head_code.0.weight_g", "head_code.0.weight_v", "head_code.1.weight_g", "head_code.1.weight_v", "head_code.2.weight_g", "head_code.2.weight_v", "head_code.3.weight_g", "head_code.3.weight_v".
    Unexpected key(s) in state_dict: "head_text.parametrizations.weight.original0", "head_text.parametrizations.weight.original1", "head_code.0.parametrizations.weight.original0", "head_code.0.parametrizations.weight.original1", "head_code.1.parametrizations.weight.original0", "head_code.1.parametrizations.weight.original1", "head_code.2.parametrizations.weight.original0", "head_code.2.parametrizations.weight.original1", "head_code.3.parametrizations.weight.original0", "head_code.3.parametrizations.weight.original1".
```

oldlee11 commented 4 months ago

I'm running into this too. Urgently waiting for a fix.

jxyk2007 commented 4 months ago

This approach solves the problem:

This error means the keys did not match when loading the state dict for GPT_warpper: some expected keys are missing while some unexpected keys are present. It usually happens when the model definition and its weight file are out of sync.

To fix it, you can try the following:

  1. Make sure the model and the weight file match. Check that the GPT_warpper definition exactly matches the weight file it was trained with; you may need to verify the versions of both the weight file and the model definition.

  2. Update the model definition. If the weight file is newer than the model definition, update the model definition so that it matches the weight file.

  3. Manually match the keys. You can edit the weight file or the model definition so that their keys line up. The example below shows how to ignore unexpected and missing keys when loading the state dict; an alternative that renames the keys instead follows right after it:

```python
import torch

# Assume gpt is a GPT_warpper instance
gpt_state_dict = torch.load(gpt_ckpt_path, map_location='cpu')

# Keep only the keys that the model actually expects
filtered_state_dict = {k: v for k, v in gpt_state_dict.items() if k in gpt.state_dict()}

# Load the filtered state dict, ignoring any missing keys
gpt.load_state_dict(filtered_state_dict, strict=False)
```
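If you would rather keep the trained head weights instead of leaving them at their initialized values, a possible alternative is to rename the checkpoint keys before loading. This is a sketch, assuming the checkpoint was saved with the parametrized weight_norm while the model uses the legacy one (in PyTorch's parametrized weight_norm, `original0` corresponds to `weight_g` and `original1` to `weight_v`); it reuses `gpt` and `gpt_ckpt_path` from above:

```python
import torch

# Map the parametrized weight_norm key names back to the legacy names
# expected by the current model definition (assumption: original0 -> weight_g,
# original1 -> weight_v).
state_dict = torch.load(gpt_ckpt_path, map_location='cpu')
renamed_state_dict = {
    k.replace('.parametrizations.weight.original0', '.weight_g')
     .replace('.parametrizations.weight.original1', '.weight_v'): v
    for k, v in state_dict.items()
}

# A strict load will now fail loudly if anything else is still mismatched.
gpt.load_state_dict(renamed_state_dict)
```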

  4. Modify the load_models function. You can add similar code to load_models so that unexpected and missing keys are ignored when the state dict is loaded.

Modify the `def load_models` function in `D:\Program Files\Python311\Lib\site-packages\ChatTTS\core.py`:

```python
def load_models(
    self,
    vocos_config_path: str = None,
    vocos_ckpt_path: str = None,
    dvae_config_path: str = None,
    dvae_ckpt_path: str = None,
    gpt_config_path: str = None,
    gpt_ckpt_path: str = None,
    decoder_config_path: str = None,
    decoder_ckpt_path: str = None,
    tokenizer_path: str = None,
    device: str = None
):
    if not device:
        device = select_device(4096)
    self.logger.log(logging.INFO, f'use {device}')

    if vocos_config_path:
        vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
        assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
        vocos.load_state_dict(torch.load(vocos_ckpt_path))
        self.pretrain_models['vocos'] = vocos
        self.logger.log(logging.INFO, 'vocos loaded.')

    if dvae_config_path:
        cfg = OmegaConf.load(dvae_config_path)
        dvae = DVAE(**cfg).to(device).eval()
        assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
        dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
        self.pretrain_models['dvae'] = dvae
        self.logger.log(logging.INFO, 'dvae loaded.')

    if gpt_config_path:
        cfg = OmegaConf.load(gpt_config_path)
        gpt = GPT_warpper(**cfg).to(device).eval()
        assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
        gpt_state_dict = torch.load(gpt_ckpt_path, map_location='cpu')
        filtered_state_dict = {k: v for k, v in gpt_state_dict.items() if k in gpt.state_dict()}
        gpt.load_state_dict(filtered_state_dict, strict=False)
        self.pretrain_models['gpt'] = gpt
        self.logger.log(logging.INFO, 'gpt loaded.')

    if decoder_config_path:
        cfg = OmegaConf.load(decoder_config_path)
        decoder = DVAE(**cfg).to(device).eval()
        assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
        decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
        self.pretrain_models['decoder'] = decoder
        self.logger.log(logging.INFO, 'decoder loaded.')

    if tokenizer_path:
        tokenizer = torch.load(tokenizer_path, map_location='cpu')
        tokenizer.padding_side = 'left'
        self.pretrain_models['tokenizer'] = tokenizer
        self.logger.log(logging.INFO, 'tokenizer loaded.')

    self.check_model()
```

The code above modifies the load_models function: when loading the gpt model, it filters unexpected keys out of the state dict and passes strict=False so that missing keys are ignored. This should help resolve the problem you're seeing.
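As a quick sanity check (a sketch reusing `gpt` and `filtered_state_dict` from the patched function above), `load_state_dict` returns an `_IncompatibleKeys` result, so you can log which parameters were actually skipped by the non-strict load:

```python
# Reuses gpt and filtered_state_dict from the patched load_models above.
result = gpt.load_state_dict(filtered_state_dict, strict=False)
if result.missing_keys:
    print('Parameters left at their initialized values:', result.missing_keys)
if result.unexpected_keys:
    print('Checkpoint keys that were ignored:', result.unexpected_keys)
```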

github-actions[bot] commented 2 months ago

This issue was closed because it has been inactive for 15 days since being marked as stale.

hqm19 commented 2 months ago

Is there a solution for this?