mreso opened this issue 2 months ago
Hey @mreso, thank you so much for flagging this! Your PR would definitely be welcomed.
Happy to create a PR to create a dtype for the precision from the string or something.
Can you clarify a bit where the kv cache would get the dtype from?
It sounds like it would be a small change. If you can run a few tests and the change passes our existing ones, I think it should be easy to approve. Feel free to think about a unit test to catch it too (that would be a huge plus! :) )
Hi @felipemello1, yes, the change should be simple, and it is probably best to use the dtype from the main config. A PR could be as easy as this change in the __init__ of InferenceRecipes:
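# after self._dtype has been resolved from the config, pass it on to the quantizer;
# guarded with hasattr so quantizers without a precision attribute are untouched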
if hasattr(self._quantizer, "precision"):
self._quantizer.precision = self._dtype
This should not break anything if a different quantizer is used. With this change, generation runs when I disable torch.compile, but this recent addition breaks compilation of the kv cache, as .item() is not supported:
...
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 527, in call_method
result = handler_method(*args, **kwargs)
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 769, in method_item
unimplemented("Tensor.item")
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 283, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item
from user code:
File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 47, in generate_next_token
logits = model(x, input_pos=input_pos)[:, -1]
File "/home/mreso/torchtune/torchtune/modules/transformer.py", line 453, in forward
h = layer(
File "/home/mreso/torchtune/torchtune/modules/transformer.py", line 99, in forward
attn_out = self.attn(self.sa_norm(x), mask=mask, input_pos=input_pos)
File "/home/mreso/torchtune/torchtune/modules/attention.py", line 266, in forward
k, v = self.kv_cache.update(input_pos, k, v)
File "/home/mreso/torchtune/torchtune/modules/kv_cache.py", line 67, in update
self.size = input_pos.max().item() + 1
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
We can skip over this error by setting
torch._dynamo.config.capture_scalar_outputs = True
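# (set before torch.compile is invoked; this lets dynamo trace Tensor.item()
# as an unbacked symbolic int instead of raising Unsupported)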
But then I run into an inductor error:
return compiled_fn(runtime_args)
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1412, in __call__
return self.current_callable(inputs)
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_inductor/utils.py", line 1899, in run
return model(new_inputs)
File "/tmp/torchinductor_mreso/aw/cawku6vajsf6xwpycmumj3pifzbwrpmjnr325vmuojplt7o2jw4r.py", line 2593, in call
extern_kernels.mm(reinterpret_tensor(buf23, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf25, (4096, 1024), (1, 4096), 0), out=buf26)
RuntimeError: Expected out tensor to have dtype float, but got c10::BFloat16 instead
Currently I am not sure what triggers that. Here is a gist with the failing kernel if you want to have a look. I can try to spend some time debugging later today.
Thanks for pointing it out! Regarding the kv cache, I chatted with @pbontrager, and apparently .item() is not necessary there. We can use something like size(dim).
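For reference, kv_cache.update could then look roughly like this — a sketch only, assuming input_pos always holds contiguous positions written in order (as the generate recipe does), not the actual torchtune code:
def update(self, input_pos, k_val, v_val):
    # input_pos: [S] positions being written; k_val / v_val: [B, H, S, D]
    # Advance the tracked fill level by the number of new positions instead of
    # reading the max position back into Python with .item(), which graph-breaks.
    self.size += input_pos.size(-1)
    k_out, v_out = self.k_cache, self.v_cache
    k_out[:, :, input_pos] = k_val
    v_out[:, :, input_pos] = v_val
    return k_out, v_out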
I read your issue paying a bit more attention, and I have a couple of questions:
Its default value is float32 and it seems to be unaffected by the model.to(dtype) call [here](https://github.com/pytorch/torchtune/blob/34162c9b2b24e6e555f6a06074dc0c34431a5aa8/recipes/generate.py#L73). I guess this would be the main issue. I find it odd that you say it's unaffected, since we train our models on bf16. Maybe we should try to dig more here first?
Then there is a second error:
RuntimeError: Invalid device string: 'bfloat16'
I think that's because it should be "bf16", not "bfloat16". What do you think?
Hi @felipemello1, thanks for looking into this again.
So the bf16 vs bfloat16 issue is a good catch, though this only applies to dtype in the main part of the config. Using "bf16" for the quantizer precision produces the same error, because the string is never converted into a dtype when it is funneled from the config into the constructor of the quantizer:
File "/home/mreso/torchtune/recipes/generate.py", line 72, in _setup_model
model = self._quantizer.quantize(model)
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 1014, in quantize
state_dict = self._create_quantized_state_dict(model)
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 989, in _create_quantized_state_dict
weight.to(self.precision),
RuntimeError: Invalid device string: 'bf16'
Note that for this error I am not setting dtype but precision for the quantizer:
quantizer:
_component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
groupsize: 256
precision: bf16
precision is a parameter of the quantizer constructor, but it does not accept a string (only a dtype).
Regarding the quantizer not being affected by .to(): the quantizer replaces the linear layers with Int8DynActInt4WeightLinear, which is derived from nn.Module, so its parameters will be affected by .to(), but precision is just an unregistered class attribute which will not change: https://github.com/pytorch/ao/blob/b523f9f9e15b6fb80d10f585d9cf45e0c5e4d10e/torchao/quantization/GPTQ.py#L927
Making this respond to .to() would mean either overriding .to() or registering a parameter with nn.Module and extracting that parameter's dtype when calling forward here. Neither is a very nice option.
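A tiny self-contained example of why .to() never touches it:
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)    # registered parameters: converted by .to()
        self.precision = torch.float32   # plain attribute: invisible to .to()

m = Toy().to(torch.bfloat16)
print(m.linear.weight.dtype)  # torch.bfloat16
print(m.precision)            # still torch.float32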
It's probably better to either change the constructor of Int8DynActInt4WeightQuantizer to accept not only a dtype for precision but a string as well, and then use something like this to convert strings to dtypes, or to convert the string between reading the config and constructing the quantizer, as proposed above.
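For the second option, the conversion could look roughly like this in the recipe — a sketch only; the helper name is made up, and it assumes utils.get_dtype accepts the same short strings ("bf16", "fp32", ...) as the top-level dtype field:
from torchtune import config, utils

def _setup_quantizer(self, quantizer_cfg):
    # hypothetical helper: build the quantizer, then normalize its precision
    quantizer = config.instantiate(quantizer_cfg)
    if hasattr(quantizer, "precision") and isinstance(quantizer.precision, str):
        # turn e.g. "bf16" from the yaml into torch.bfloat16
        quantizer.precision = utils.get_dtype(quantizer.precision)
    return quantizer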
@felipemello1 Turns out the precision in the quantizer was a red herring... it actually needs to be float32 to capture the result of the matmul (see the inductor error mentioned above). I've now cast the k, v, q tensors to the right dtype and made minor changes to make it compile. I will put the changes in a PR for further discussion.
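The cast is roughly of this shape (a sketch of the idea only — the actual change, including the compile fixes, is in the PR below; the k_cache attribute name is an assumption based on torchtune's KVCache buffers):
# inside the attention forward, before the cache is updated: bring q, k, v to a
# common dtype matching what the kv cache was allocated with
if self.kv_cache is not None:
    cache_dtype = self.kv_cache.k_cache.dtype
    q = q.to(cache_dtype)
    k = k.to(cache_dtype)
    v = v.to(cache_dtype)
    k, v = self.kv_cache.update(input_pos, k, v)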
@felipemello1 Here is the PR: #1371
Hi,
I am trying to apply the generate recipe to a quantized Llama 3.1 8B model but run into the following error:
I looked a bit through the code of torchtune and torchao and found that linear layers like k_proj are wrapped with Int8DynActInt4WeightLinear, which has a precision argument that seems to get used here to determine the precision of the computation. Its default value is float32 and it seems to be unaffected by the model.to(dtype) call here.
Thus the value that gets used to update the kv cache is float32 instead of bfloat16. Trying to set the precision in the configuration only results in this error:
Not sure if there is a way to set the precision correctly from the config. Happy to create a PR that creates a dtype for the precision from the string or something along those lines.
Steps to reproduce
quantize.yaml:
generate.yaml:
Full logs (error stack trace):