The existing code defines a variable `precision` (here), which is then used in `_load_model()` (here) to set the dtype for the model. However, this variable was not passed along to all of the relevant functions, namely `KVCache`, leading to issues when using `--compile`.
This PR simply passes the existing `precision` variable along to those functions. I verified the fix by successfully compiling and generating text with the command:
```
python generate.py --compile --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int8.pth --prompt "Hello, my name is"
```
while running on an NVIDIA TITAN V with `precision = torch.float16`. Note that Volta-architecture GPUs do not support bfloat16, hence the desire to make the code support other precisions.
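To illustrate the kind of mismatch this fixes, here is a minimal sketch (not the actual repository code; the helper name and shapes are hypothetical) of allocating KV-cache buffers in an explicitly requested dtype rather than a hard-coded default:

```python
import torch

def make_kv_cache(max_batch_size, max_seq_length, n_heads, head_dim,
                  dtype=torch.bfloat16):
    # Hypothetical helper: allocate the key/value caches in the requested
    # dtype so they match the model's precision. If the dtype is not
    # forwarded, the caches default to bfloat16 here, which clashes with
    # a float16 model under torch.compile on GPUs without bfloat16 support.
    cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
    k_cache = torch.zeros(cache_shape, dtype=dtype)
    v_cache = torch.zeros(cache_shape, dtype=dtype)
    return k_cache, v_cache

precision = torch.float16  # Volta GPUs lack bfloat16 support
k, v = make_kv_cache(1, 128, 8, 64, dtype=precision)
```

The fix in this PR amounts to the same idea: forwarding the already-defined `precision` variable to the cache setup instead of letting it fall back to a default dtype.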