pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

Extended support for existing precision variable #24

Open ankitvgupta opened 7 months ago

ankitvgupta commented 7 months ago

The existing code defines a variable precision (here), which is then used in _load_model() here to set the dtype for the model. However, this variable was not being passed to all of the relevant functions, namely KVCache, leading to issues when using --compile.
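To make the change concrete, here is a minimal sketch of the idea: the cache module accepts a dtype argument that the caller threads through from the model's precision, rather than relying on a hard-coded bfloat16 default. This is a simplified, hypothetical version for illustration; the real KVCache in gpt-fast's model.py has more moving parts.

```python
import torch
import torch.nn as nn

class KVCache(nn.Module):
    # Simplified, hypothetical sketch of a KV cache. The point of the PR:
    # accept the caller's dtype instead of always allocating in the default
    # precision, so GPUs without bfloat16 support (e.g. Volta) can use float16.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim,
                 dtype=torch.bfloat16):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # Write new keys/values at the given positions, return the full caches.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

# The caller passes the model's precision along instead of relying on the default:
precision = torch.float16  # e.g. on Volta GPUs, which lack bfloat16 support
cache = KVCache(max_batch_size=1, max_seq_length=8, n_heads=2, head_dim=4,
                dtype=precision)
```

Without the dtype parameter, the cache would be allocated in bfloat16 while the model's activations are float16, and the mismatch surfaces once torch.compile traces through the cache update.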

This PR simply passes the existing precision variable along to those functions. I verified the fix by successfully compiling and generating text with the command:

python generate.py --compile --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int8.pth --prompt "Hello, my name is"

while running on an NVIDIA TITAN V with precision = torch.float16. Note that Volta-architecture GPUs don't support bfloat16, hence the desire to make the code support other precisions.