[Bug] torch.isin(elements=inputs, test_elements=pad_token_id).any() TypeError: isin() received an invalid combination of arguments - got (elements=Tensor, test_elements=int, )

Describe the bug

```
torch.isin(elements=inputs, test_elements=pad_token_id).any()
TypeError: isin() received an invalid combination of arguments - got (elements=Tensor, test_elements=int, ), but expected one of:
```
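The type mismatch can be reproduced without XTTS. This is a minimal sketch that assumes a PyTorch build behaving like the one in this report. The `inputs` tensor is made up, and 1025 is only the id used in the code below. It shows that the keyword call rejects a plain `int` for `test_elements`, and that passing the id as a tensor on the same device succeeds. It only illustrates the mismatch and is not the fix recommended in the replies.

```python
import torch

# Made-up token batch; 1025 stands in for the pad/eos id from the issue.
inputs = torch.tensor([[5, 1025, 7]])
pad_token_id = 1025  # a plain Python int, matching "test_elements=int" in the error

# With keyword arguments, an int does not match the (Tensor elements, Tensor test_elements)
# overload, so PyTorch versions like the one in this report raise TypeError here.
try:
    torch.isin(elements=inputs, test_elements=pad_token_id)
except TypeError as err:
    print("TypeError:", str(err).splitlines()[0])

# Wrapping the id in a tensor on the same device selects the Tensor/Tensor overload.
pad_tensor = torch.tensor([pad_token_id], device=inputs.device)
has_pad = bool(torch.isin(elements=inputs, test_elements=pad_tensor).any())
print(has_pad)
```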


```python
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

print("Loading model...")
json_path = "D:/xtts2/config.json"
xtts_checkpoint = "D:/xtts2/"
config = XttsConfig()
config.load_json(json_path)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=xtts_checkpoint, use_deepspeed=False)
model.cuda()
print("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["../tts/reference.wav"])

chunks = model.inference_stream(
    "今天天气真好",
    "zh-cn",
    gpt_cond_latent,
    speaker_embedding,
    pad_token_id=torch.tensor([1025], device=model.device),
    eos_token_id=torch.tensor([1025], device=model.device),
)
wav_chuncks = []
for i, chunk in enumerate(chunks):
    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
    wav_chuncks.append(chunk)
wav = torch.cat(wav_chuncks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```

I solved this by writing my own class that inherits from `StreamGenerationConfig` and overrides the `update` method. The cause of the problem is that `update` overwrites the original configuration with the values passed in through `kwargs`. Why doesn't Python have something like Spring Boot's property copying that leaves existing properties untouched? Here is the complete code; I hope it helps:

```python
import torch
import torchaudio
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
from TTS.tts.models.xtts import Xtts


class TokenConfig(StreamGenerationConfig):
    def __init__(self, pad_token_id, eos_token_id, **kwargs):
        super().__init__(**kwargs)
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

    def update(self, **kwargs):
        # Apply every known setting except the two token ids set in __init__,
        # so those values are not overwritten by incoming kwargs.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key) and key != 'pad_token_id' and key != 'eos_token_id':
                setattr(self, key, value)
                to_remove.append(key)
        # Report no unused kwargs back to the caller.
        return {}


device = "cuda" if torch.cuda.is_available() else "cpu"
print("use device {}".format(device))

if __name__ == '__main__':
    print("Loading model...")
    json_path = "D:/xtts2/config.json"
    xtts_checkpoint = "D:/xtts2/"
    config = XttsConfig()
    config.load_json(json_path)
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir=xtts_checkpoint, use_deepspeed=False)
    model.cuda()
    print("Computing speaker latents...")
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["../tts/reference.wav"])
    chunks = model.inference_stream(
        "今天天气真好",
        "zh-cn",
        gpt_cond_latent,
        speaker_embedding,
        generation_config=TokenConfig(
            pad_token_id=torch.tensor([1025], device=model.device),
            eos_token_id=torch.tensor([1025], device=model.device),
        ),
    )
    wav_chuncks = []
    for i, chunk in enumerate(chunks):
        print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
        wav_chuncks.append(chunk)
    wav = torch.cat(wav_chuncks, dim=0)
    torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
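The idea behind `TokenConfig.update` (apply incoming settings but keep selected attributes untouched) needs no special library in Python; a few lines with `hasattr`/`setattr` are enough. This is a minimal, framework-free sketch. `DemoConfig` and `update_except` are hypothetical names and are not part of TTS or transformers. Unknown keys are returned to the caller, while protected keys are silently skipped.

```python
from typing import Any, Dict, Iterable


class DemoConfig:
    """Hypothetical stand-in for a generation config (not the real transformers class)."""

    def __init__(self, pad_token_id=None, eos_token_id=None, temperature=1.0):
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.temperature = temperature


def update_except(obj: Any, protected: Iterable[str], **kwargs: Any) -> Dict[str, Any]:
    """Copy kwargs onto obj's existing attributes, leaving names in `protected` untouched.

    Returns only the kwargs whose names obj does not have as attributes.
    """
    protected = set(protected)
    unused = {}
    for key, value in kwargs.items():
        if not hasattr(obj, key):
            unused[key] = value  # unknown setting: hand it back to the caller
        elif key not in protected:
            setattr(obj, key, value)  # known and not protected: overwrite
        # known but protected: keep the existing value
    return unused


cfg = DemoConfig(pad_token_id=[1025], eos_token_id=[1025])
left = update_except(cfg, {"pad_token_id", "eos_token_id"},
                     pad_token_id=1025, temperature=0.75, top_k=50)
print(cfg.pad_token_id, cfg.temperature, left)  # [1025] 0.75 {'top_k': 50}
```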

eginhard commented 2 weeks ago

Duplicate of …

It should work fine if you install `transformers` version 4.40.2 or lower. Installing our fork (`pip install coqui-tts`) will also take care of that.
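Before downgrading, it can help to confirm which `transformers` version is installed; the pin itself would be something like `pip install "transformers<=4.40.2"`. This is a small sketch that uses only the standard library (`importlib.metadata`). The helper names are hypothetical, and the 4.40.2 bound comes from the reply above. It compares only the numeric release segment, so it is a rough check and not a full PEP 440 comparison.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Upper bound from the maintainer's reply: versions up to this one are reported to work.
MAX_COMPATIBLE: Tuple[int, ...] = (4, 40, 2)


def parse_release(ver: str) -> Tuple[int, ...]:
    """Return the leading numeric release segment, e.g. '4.41.0.dev0' -> (4, 41, 0)."""
    parts = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at suffixes such as 'dev0'
        parts.append(int(match.group()))
        if match.group() != piece:
            break  # stop after mixed pieces such as '0rc1'
    return tuple(parts)


def is_compatible(ver: str) -> bool:
    """True if `ver` is 4.40.2 or lower (pre-release suffixes are ignored)."""
    return parse_release(ver) <= MAX_COMPATIBLE


def installed_transformers_ok() -> Optional[bool]:
    """Check the installed transformers package; None if it is not installed."""
    try:
        return is_compatible(version("transformers"))
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(installed_transformers_ok())
```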

XPDD commented 2 weeks ago

Thank you, it works now.