Camb-ai / MARS5-TTS

MARS5 speech model (TTS) from CAMB.AI
https://www.camb.ai
GNU Affero General Public License v3.0
1.37k stars · 95 forks

Colab Demo fails to run. #13

Closed kvrban closed 1 week ago

kvrban commented 2 weeks ago

The issue:

RuntimeError                              Traceback (most recent call last)

[<ipython-input-7-2d05018561f0>](https://localhost:8080/#) in <cell line: 6>()
      4                       top_k=100, temperature=0.7, freq_penalty=3)
      5 
----> 6 ar_codes, wav_out = mars5.tts("The quick brown rat.", wav, 
      7           ref_transcript,
      8           cfg=cfg)

13 frames

[~/.cache/torch/hub/Camb-ai_mars5-tts_master/mars5/nn_future.py](https://localhost:8080/#) in forward(self, x, freqs_cis, positions, mask, cache)
    249             scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None]
    250             scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
--> 251             cache.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])
    252             cache.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])
    253 

RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype
akshhack commented 2 weeks ago

hey @kvrban thanks for pointing it out; we'll take a look and update when it's fixed.

RF5 commented 2 weeks ago

Hi @kvrban , can you try with a T4 GPU on colab? It would appear this is a CPU-only bug with autocast. We'll have a fix for CPU-only inference and a few other CPU-friendly fixes soon.
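To illustrate the failure mode: under CPU autocast the freshly computed keys/values come out in a lower precision than the float32 KV cache, and torch's in-place `scatter_` refuses to write a `src` whose dtype differs from `self`. The usual remedy is to cast the source to the cache's dtype before the write. A rough NumPy sketch of that cast-before-write pattern (with `np.put_along_axis` standing in for `scatter_`; all shapes and names here are hypothetical, not the actual MARS5 code):

```python
import numpy as np

# KV cache kept in float32; under CPU autocast the new keys arrive as float16.
cache_k = np.zeros((1, 4, 2), dtype=np.float32)   # (batch, window, head_dim)
xk = np.ones((1, 4, 2), dtype=np.float16)         # new keys, lower precision
scatter_pos = np.arange(4)[None, :, None]         # positions within the window

# The fix pattern: cast src to the cache dtype before the in-place write.
np.put_along_axis(cache_k, scatter_pos, xk.astype(cache_k.dtype), axis=1)

print(cache_k.dtype, cache_k.sum())  # cache dtype is preserved: float32 8.0
```

NumPy would silently cast here rather than raise as torch does, so the explicit `.astype(...)` is shown only to make the intended fix visible.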

kvrban commented 2 weeks ago

Yeah, on a T4 GPU it does run.

kvrban commented 2 weeks ago

https://github.com/Camb-ai/MARS5-TTS/assets/33060804/48e66002-a3e7-47a2-8843-d610fbedf918

but the synthesized output audio is just a hum sound. (I had to convert the original wav to an mp4 to upload it here.)

popo0293 commented 1 week ago


try a different set of params:

cfg = config_class(deep_clone=deep_clone, rep_penalty_window=100, top_p=0.8, temperature=1.0, freq_penalty=3)

origin-s20 commented 1 week ago


Hi @kvrban, I'm facing a similar issue on a MacBook M3 Pro. Is there a fix for this?

pieterscholtz commented 1 week ago

@kvrban @origin-s20 We've merged a fix for this. For it to take effect, you may need to delete your torch hub cache before trying again, e.g. `rm -rf ~/.cache/torch/hub/Camb-ai_mars5-tts_master`, or simply pass `force_reload=True` to the `torch.hub.load` call. Please note that this was a CPU-only bug, and that inference on CPU will be quite slow.
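The cache-clearing step can also be done from Python, a minimal sketch (the path is taken from the traceback above; adjust if your torch hub directory differs):

```python
import shutil
from pathlib import Path

# Cached torch.hub checkout, as seen in the traceback above.
# Deleting it forces torch.hub to re-fetch the repo (with the fix)
# on the next load.
cache = Path.home() / '.cache/torch/hub/Camb-ai_mars5-tts_master'
if cache.exists():
    shutil.rmtree(cache)
```

Alternatively, passing `force_reload=True` to `torch.hub.load` achieves the same thing without deleting the cache by hand.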

origin-s20 commented 1 week ago


@pieterscholtz thanks it worked!