2noise / ChatTTS

A generative speech model for daily dialogue.
https://2noise.com
GNU Affero General Public License v3.0

decoder error on all_codes.masked_fill & what's the correct version of vector_quantize_pytorch #744

Open unbelievable3513 opened 1 week ago

unbelievable3513 commented 1 week ago

The following error occurred after switching the default decoder to the DVAE (i.e., inferring with use_decoder=False). Could it be caused by an incompatible version of vector_quantize_pytorch==1.17.3? However, I have also tried vector-quantize-pytorch==1.16.1, 1.15.5, and 1.14.24 without success.

  File "/workspace/ChatTTS/ChatTTS/model/dvae.py", line 95, in _embed
    feat = self.quantizer.get_output_from_indices(x)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/residual_fsq.py", line 248, in get_output_from_indices
    outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/residual_fsq.py", line 248, in <genexpr>
    outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/residual_fsq.py", line 134, in get_output_from_indices
    codes = self.get_codes_from_indices(indices)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/residual_fsq.py", line 120, in get_codes_from_indices
    all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)
RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [2,0,0], thread: [36,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
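For reference, the RuntimeError looks like a plain device mismatch: the mask built inside residual_fsq.py is on the CPU while the code tensor is on cuda:0, which suggests the index tensor passed into _embed never reaches the GPU. Below is a minimal, hypothetical sketch of the alignment workaround I am considering; the helper name is mine and not part of ChatTTS or vector-quantize-pytorch.

```python
import torch

def to_quantizer_device(indices: torch.Tensor, quantizer: torch.nn.Module) -> torch.Tensor:
    """Move code indices onto the same device as the quantizer's weights/buffers."""
    try:
        device = next(quantizer.parameters()).device
    except StopIteration:
        # FSQ-style quantizers may register only buffers, not parameters.
        device = next(quantizer.buffers()).device
    return indices.to(device)

# Intended usage inside ChatTTS/model/dvae.py, _embed (see traceback above):
#   x = to_quantizer_device(x, self.quantizer)
#   feat = self.quantizer.get_output_from_indices(x)
```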

environment:

av                                13.0.0
chattts                           0.0.0             /workspace/ChatTTS
gradio                            4.42.0
gradio_client                     1.3.0
nemo_text_processing              1.0.2
numba                             0.60.0
numpy                             1.26.4
pybase16384                       0.3.7
pydub                             0.25.1
pynini                            2.1.5
torch                             2.1.2
torchaudio                        2.1.2
tqdm                              4.66.5
transformers                      4.44.2
transformers-stream-generator     0.0.5
vector-quantize-pytorch           1.16.1
vocos                             0.1.0
WeTextProcessing                  1.0.3

Would you be able to offer me some suggestions, please?

fumiama commented 1 week ago

You can try the Colab notebook linked in the README. It uses vector-quantize-pytorch 1.17.3 and has no problem when use_decoder=False is given.
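For completeness, this is roughly what that setup boils down to, assuming the current Chat.load / Chat.infer API (method names may differ slightly between ChatTTS releases):

```python
# pip install vector-quantize-pytorch==1.17.3
import ChatTTS

chat = ChatTTS.Chat()
chat.load(compile=False)  # older releases expose chat.load_models() instead

texts = ["Hello, this is a quick DVAE decoding test."]
# Decode the audio codes through the DVAE path instead of the default decoder.
wavs = chat.infer(texts, use_decoder=False)
```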