lenML / ChatTTS-Forge

🍦 ChatTTS-Forge is a project built around a TTS generation model, implementing an API server and a Gradio-based WebUI.
https://huggingface.co/spaces/lenML/ChatTTS-Forge
GNU Affero General Public License v3.0
585 stars 71 forks

[ISSUE] Why is dtype hard-coded to float32 when loading the model? #94

Closed wenyangchou closed 1 month ago

wenyangchou commented 1 month ago

Read README.md and dependencies.md

Searched issues and discussions

Checked the Forge version

Your issue

https://github.com/lenML/ChatTTS-Forge/blob/main/modules/models.py#L34

def load_chat_tts_in_thread():
    global chat_tts
    if chat_tts:
        return

    logger.info("Loading ChatTTS models")
    chat_tts = ChatTTS.Chat()
    device = devices.get_device_for("chattts")
    dtype = devices.dtype
    chat_tts.load(
        compile=config.runtime_env_vars.compile,
        use_flash_attn=config.runtime_env_vars.flash_attn,
        source="custom",
        custom_path="./models/ChatTTS",
        device=device,
        dtype=dtype,
        # dtype_vocos=devices.dtype_vocos,
        # dtype_dvae=devices.dtype_dvae,
        # dtype_gpt=devices.dtype_gpt,
        # dtype_decoder=devices.dtype_decoder,
    )

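Why are the dtype arguments here commented out, with dtype hard-coded to float32? Is there a known pitfall? Written this way, flash_attn cannot be used.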
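To illustrate why a float32 dtype and flash_attn conflict, here is a minimal sketch. It assumes that FlashAttention kernels accept only float16 or bfloat16 inputs, as the flash-attn README states. The helper names `resolve_flash_attn` and `explain` are hypothetical and are not part of ChatTTS-Forge or ChatTTS; dtypes are passed as plain strings so the sketch runs without torch.

```python
# Hypothetical helpers for illustration only; not part of ChatTTS-Forge or ChatTTS.
# Assumption: FlashAttention supports only half-precision inputs
# (float16 or bfloat16), per the flash-attn README.
FLASH_ATTN_DTYPES = {"float16", "bfloat16"}


def resolve_flash_attn(requested: bool, dtype_name: str) -> bool:
    """Return True only if flash attention was requested and the dtype allows it."""
    return requested and dtype_name in FLASH_ATTN_DTYPES


def explain(requested: bool, dtype_name: str) -> str:
    """Describe whether flash attention is usable for this configuration."""
    if not requested:
        return "flash_attn not requested"
    if dtype_name not in FLASH_ATTN_DTYPES:
        return f"flash_attn disabled: dtype {dtype_name} is not float16/bfloat16"
    return f"flash_attn enabled with dtype {dtype_name}"


if __name__ == "__main__":
    # Walk through the three dtypes relevant to the question.
    for name in ("float32", "float16", "bfloat16"):
        print(explain(True, name))
```

Under that assumption, a model loaded in float32 would have to fall back to a non-flash attention path even when the flash_attn option is set.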

wenyangchou commented 1 month ago

There is a bug when it runs through transformer: the dtype is not loaded correctly. Is that the reason?

zhzLuke96 commented 1 month ago

Please provide the error message; otherwise the problem cannot be located.