metavoiceio / metavoice-src

Foundational model for human-like, expressive TTS
https://themetavoice.xyz/
Apache License 2.0

fix: propagate `precision` correctly to enable non-bf16 inference #165

Open Icedgarr opened 5 months ago

Icedgarr commented 5 months ago

This PR fixes some incompatibilities I encountered when instantiating TTS from fam/llm/fast_inference.py on older, less powerful GPUs (e.g. a Google Colab T4 GPU).

fam/llm/fast_inference_utils.py was moving the model to the device (cuda) with a hard-coded bfloat16 dtype instead of using the `precision` parameter, which holds the selected dtype (by default float16 or bfloat16, depending on the GPU architecture).
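For illustration, here is a minimal sketch of the intended behaviour. The helper names (`get_default_dtype`, `load_model`) are assumptions chosen for this example, not the repository's exact function names:

```python
import torch


def get_default_dtype() -> torch.dtype:
    # Assumed selection logic: prefer bfloat16 where the GPU supports it
    # (e.g. Ampere or newer), otherwise fall back to float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


def load_model(model: torch.nn.Module, device: str, precision: torch.dtype) -> torch.nn.Module:
    # Before the fix: the dtype was effectively hard-coded, e.g.
    #   model.to(device=device, dtype=torch.bfloat16)
    # After the fix: the selected precision is propagated instead.
    return model.to(device=device, dtype=precision)
```

On a T4 (which has no native bfloat16 support) this keeps the model in float16 end to end instead of silently forcing bfloat16.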

The linear layer of the `Attention` class in fam/llm/fast_model.py was also missing an explicit dtype; it should use the one provided in the config.
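A minimal sketch of that change, assuming an illustrative config dataclass and layer names (`ModelArgs`, `wqkv`, `wo` are stand-ins, not necessarily the repository's exact identifiers):

```python
import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass
class ModelArgs:
    # Illustrative config fields; names are assumptions for this example.
    dim: int = 2048
    dtype: torch.dtype = torch.bfloat16


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        # Without an explicit dtype these projections default to float32,
        # which clashes with a model otherwise running in float16/bfloat16.
        # Passing config.dtype keeps every parameter in the selected precision.
        self.wqkv = nn.Linear(config.dim, 3 * config.dim, bias=False, dtype=config.dtype)
        self.wo = nn.Linear(config.dim, config.dim, bias=False, dtype=config.dtype)
```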

Icedgarr commented 5 months ago

I have reverted the last commit since it was not required for this fix.