This PR fixes some incompatibilities that I encountered when instantiating `TTS` from `fam/llm/fast_inference.py` on older and less powerful GPUs (e.g. the Google Colab T4 GPU).
`fam/llm/fast_inference_utils.py` was moving the model to the device (`cuda`) with a hard-coded `bfloat16` dtype instead of using the `precision` parameter, which carries the selected dtype (by default `float16` or `bfloat16`, depending on the GPU architecture).
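Roughly, the change amounts to something like the sketch below; the function name and signature are illustrative, not the repository's exact code:

```python
import torch

def move_model_to_device(model: torch.nn.Module, device: str, precision: torch.dtype) -> torch.nn.Module:
    """Illustrative sketch (not the exact repo code) of respecting the selected precision."""
    # Before: model.to(device=device, dtype=torch.bfloat16)
    #   -> fails or degrades on GPUs without native bfloat16 support (e.g. the T4).
    # After: use the dtype selected upstream (float16 on older architectures).
    return model.to(device=device, dtype=precision)
```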
The linear layer of the `Attention` class in `fam/llm/fast_model.py` was also missing a dtype definition; it now uses the dtype provided in the config.
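For that second point, the idea is roughly the following sketch; the layer and parameter names are hypothetical and only illustrate passing the configured dtype to the linear projections:

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    # Hypothetical, simplified attention module; the point is that the linear
    # projections are created with the dtype from the config rather than the default.
    def __init__(self, dim: int, n_heads: int, head_dim: int, dtype: torch.dtype = torch.float16):
        super().__init__()
        self.wqkv = nn.Linear(dim, 3 * n_heads * head_dim, bias=False, dtype=dtype)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False, dtype=dtype)
```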