pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source. #1349

Open mreso opened 2 months ago

mreso commented 2 months ago

Hi,

I am trying to apply the generate recipe on a quantized llama 3.1 8B model but run into the following error:

...
  File "/home/mreso/torchtune/torchtune/modules/attention.py", line 211, in forward
    k, v = self.kv_cache.update(input_pos, k, v)
  File "/home/mreso/torchtune/torchtune/modules/kv_cache.py", line 69, in update
    k_out[:, :, input_pos] = k_val
RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.
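
For readers who want to see the failure in isolation, here is a minimal sketch of the same assignment outside torchtune. The tensor shapes and the names k_out, k_val and input_pos only mirror the traceback; the sizes are made up for illustration, and this is not the library's cache code.

```python
import torch

# Toy stand-in for the cache buffer: (batch, heads, max_seq_len, head_dim) in bf16.
k_out = torch.zeros(1, 2, 8, 4, dtype=torch.bfloat16)
# Toy stand-in for the projected keys, which arrive in float32.
k_val = torch.ones(1, 2, 1, 4, dtype=torch.float32)
input_pos = torch.tensor([3])

error_message = None
try:
    # Same statement as in the traceback: bf16 destination, float32 source.
    k_out[:, :, input_pos] = k_val
except Exception as err:  # reported as RuntimeError in the traceback above
    error_message = str(err)

print("error:", error_message)

# Casting the source to the destination dtype lets the write go through.
k_out[:, :, input_pos] = k_val.to(k_out.dtype)
print(k_out[0, 0, 3])
```

The cast on the last assignment is only one possible workaround; where the cast should live in the real code is discussed further down the thread.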

I looked through the torchtune and torchao code and found that linear layers such as k_proj are wrapped with Int8DynActInt4WeightLinear, which has a precision argument that seems to determine the precision of the computation. Its default value is float32, and it seems to be unaffected by the model.to(dtype) call in generate.py.

> /home/mreso/torchtune/generate.py(77)_setup_model()
-> model.load_state_dict(model_state_dict)
(Pdb) print(model.layers[0].attn.q_proj.precision)
torch.float32

Thus the value that gets used to update the kv cache is float32 instead of bfloat16. Trying to set the precision in the configuration only results in this error:

  File "generate.py", line 74, in _setup_model
    model = self._quantizer.quantize(model)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 1014, in quantize
    state_dict = self._create_quantized_state_dict(model)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 989, in _create_quantized_state_dict
    weight.to(self.precision),
RuntimeError: Invalid device string: 'bfloat16'

Not sure if there is a way to set the precision correctly from the config. Happy to create a PR to create a dtype for the precision from the string or something.
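
As a rough illustration of what "create a dtype for the precision from the string" could look like: Tensor.to() treats a bare string as a device name, which is why the config value ends up in an "Invalid device string" error. The helper to_dtype and its lookup table below are hypothetical names invented for this sketch, not part of torchtune or torchao.

```python
import torch

# Hypothetical lookup table; the accepted spellings are an assumption.
_NAME_TO_DTYPE = {
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "fp16": torch.float16,
    "float16": torch.float16,
    "fp32": torch.float32,
    "float32": torch.float32,
}


def to_dtype(precision):
    """Return a torch.dtype for a dtype or a known string name (hypothetical helper)."""
    if isinstance(precision, torch.dtype):
        return precision
    try:
        return _NAME_TO_DTYPE[precision]
    except KeyError:
        raise ValueError(f"Unsupported precision: {precision!r}") from None


weight = torch.ones(2, 2)

# A bare string is parsed as a device, so this raises instead of casting.
try:
    weight.to("bfloat16")
    string_failed = False
except RuntimeError:
    string_failed = True

# Converting the string first gives Tensor.to() a real dtype.
converted = weight.to(to_dtype("bfloat16"))
print(string_failed, converted.dtype)
```

Such a conversion could sit either in the recipe (between reading the config and building the quantizer) or in the quantizer constructor; the sketch does not take a side.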

Steps to reproduce

tune run quantize --config quantize.yaml

quantize.yaml:

# Config for QuantizationRecipe in quantize.py
#
# To launch, run the following command from root torchtune directory:
#    tune run quantize --config quantization

#
# Model arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: Meta-Llama-3-8B-Instruct/
  checkpoint_files: [
    meta_model_0.pt
  ]
  recipe_checkpoint: null
  output_dir: ./quantized
  model_type: LLAMA3

device: cuda
dtype: bf16
seed: 1234
max_seq_len: 2048

quantizer:
  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256

tune run generate --config generate.yaml

generate.yaml:

# Config for running the InferenceRecipe in generate.py to generate output from an LLM
#
# To launch, run the following command from root torchtune directory:
#    tune run generate --config generation

# Model arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
  max_seq_len: 2048

checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_dir: quantized/
  checkpoint_files: [
    meta_model_0-8da4w.pt
  ]
  output_dir: model-output
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /home/mreso/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/8c22764a7e3675c50d4c7c9a4edb474456022b16/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: 'Amanda: I baked  cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-)'
# instruct_template: torchtune.data.SummarizeTemplate
#chat_format: null
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
enable_kv_cache: True

quantizer:
  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
#  precision: bfloat16 # uncommenting this results in the second error

Full logs (full error stack trace):

INFO:torchtune.utils.logging:Starting compilation to improve generation performance ...
Traceback (most recent call last):
  File "/home/mreso/.conda/envs/tune/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/home/mreso/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/mreso/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/mreso/torchtune/torchtune/_cli/run.py", line 179, in _run_cmd
    self._run_single_device(args)
  File "/home/mreso/torchtune/torchtune/_cli/run.py", line 93, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/home/mreso/.conda/envs/tune/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/mreso/.conda/envs/tune/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/mreso/.conda/envs/tune/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/mreso/torchtune/recipes/generate.py", line 204, in <module>
    sys.exit(main())
  File "/home/mreso/torchtune/torchtune/config/_parse.py", line 50, in wrapper
    sys.exit(recipe_main(conf))
  File "/home/mreso/torchtune/recipes/generate.py", line 200, in main
    recipe.generate(cfg=cfg)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mreso/torchtune/recipes/generate.py", line 151, in generate
    _ = utils.generate(
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 122, in generate
    tokens = generate_next_token(
  File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 47, in generate_next_token
    logits = model(x, input_pos=input_pos)[:, -1]
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/torchtune/torchtune/modules/transformer.py", line 243, in forward
    h = layer(h, mask=mask, input_pos=input_pos)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/torchtune/torchtune/modules/transformer.py", line 72, in forward
    attn_out = self.attn(self.sa_norm(x), mask=mask, input_pos=input_pos)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mreso/torchtune/torchtune/modules/attention.py", line 211, in forward
    k, v = self.kv_cache.update(input_pos, k, v)
  File "/home/mreso/torchtune/torchtune/modules/kv_cache.py", line 69, in update
    k_out[:, :, input_pos] = k_val
RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.

felipemello1 commented 2 months ago

Hey @mreso, thank you so much for flagging this! Your PR would definitely be welcome.

"Happy to create a PR to create a dtype for the precision from the string or something."

Can you clarify a bit where the kv cache would get the dtype from?

It sounds like it would be a small change. If you can run a few tests and the change passes our existing ones, I think it should be easy to approve. Feel free to think about a unit test to catch it too (that would be a huge plus! :))

mreso commented 2 months ago

Hi @felipemello1, yes, the change should be simple, and it is probably best to use the dtype from the main config. A PR could be as easy as this change in __init__ of InferenceRecipe:

if hasattr(self._quantizer, "precision"):
    self._quantizer.precision = self._dtype

This should not break anything if a different quantizer is used. With this change, generation runs when I disable torch.compile, but a recent addition breaks compilation of the kv cache because .item() is not supported:

...
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 769, in method_item
    unimplemented("Tensor.item")
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 283, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

from user code:
   File "/home/mreso/torchtune/torchtune/utils/_generation.py", line 47, in generate_next_token
    logits = model(x, input_pos=input_pos)[:, -1]
  File "/home/mreso/torchtune/torchtune/modules/transformer.py", line 453, in forward
    h = layer(
  File "/home/mreso/torchtune/torchtune/modules/transformer.py", line 99, in forward
    attn_out = self.attn(self.sa_norm(x), mask=mask, input_pos=input_pos)
  File "/home/mreso/torchtune/torchtune/modules/attention.py", line 266, in forward
    k, v = self.kv_cache.update(input_pos, k, v)
  File "/home/mreso/torchtune/torchtune/modules/kv_cache.py", line 67, in update
    self.size = input_pos.max().item() + 1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

We can skip over this error by setting

torch._dynamo.config.capture_scalar_outputs = True

But then I run into an inductor error:

    return compiled_fn(runtime_args)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1412, in __call__
    return self.current_callable(inputs)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/_inductor/utils.py", line 1899, in run
    return model(new_inputs)
  File "/tmp/torchinductor_mreso/aw/cawku6vajsf6xwpycmumj3pifzbwrpmjnr325vmuojplt7o2jw4r.py", line 2593, in call
    extern_kernels.mm(reinterpret_tensor(buf23, (1, 4096), (4096, 1), 0), reinterpret_tensor(buf25, (4096, 1024), (1, 4096), 0), out=buf26)
RuntimeError: Expected out tensor to have dtype float, but got c10::BFloat16 instead

Currently I am not sure what triggers that. Here is a gist with the failing kernel if you want to have a look. I can try to spend some time debugging later today.

felipemello1 commented 2 months ago

Thanks for pointing it out! Regarding the kv cache, I chatted with @pbontrager, and apparently .item() is not necessary there. We can use something like size(dim).
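
To make the size(dim) idea concrete, here is a minimal sketch of a cache that keeps its fill count from tensor shapes instead of reading a value back with .item(). TinyKVCache is a hypothetical toy class, not torchtune's KVCache; it assumes positions are written contiguously starting at 0, and whether this exact form compiles cleanly was not checked here.

```python
import torch
from torch import nn


class TinyKVCache(nn.Module):
    """Toy cache for illustration only (hypothetical, not torchtune's KVCache)."""

    def __init__(self, batch_size, num_heads, max_seq_len, head_dim, dtype):
        super().__init__()
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype), persistent=False)
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype), persistent=False)
        self.size = 0

    def update(self, input_pos, k_val, v_val):
        # Shape-based bookkeeping: no tensor value is pulled back into Python.
        # Assumes positions are filled contiguously from 0.
        self.size += input_pos.size(0)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


cache = TinyKVCache(1, 2, 8, 4, torch.bfloat16)

# Prefill three positions, then decode one more token.
prefill_pos = torch.arange(3)
cache.update(prefill_pos, torch.ones(1, 2, 3, 4, dtype=torch.bfloat16),
             torch.ones(1, 2, 3, 4, dtype=torch.bfloat16))
decode_pos = torch.tensor([3])
cache.update(decode_pos, torch.ones(1, 2, 1, 4, dtype=torch.bfloat16),
             torch.ones(1, 2, 1, 4, dtype=torch.bfloat16))
print(cache.size)
```

Under the contiguous-fill assumption the counter matches input_pos.max().item() + 1 after each call; for other access patterns the two would differ.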

I read your issue paying a bit more attention, and I have a couple of questions:

"Its default value is float32 and it seems to be unaffected by the model.to(dtype) call [here](https://github.com/pytorch/torchtune/blob/34162c9b2b24e6e555f6a06074dc0c34431a5aa8/recipes/generate.py#L73)."

I guess this would be the main issue. I find it odd that you say it's unaffected, since we train our models in bf16. Maybe we should try to dig more here first?

Then there is a second error:

RuntimeError: Invalid device string: 'bfloat16'

I think that's because it should be "bf16", not "bfloat16".

What do you think?

mreso commented 2 months ago

Hi @felipemello1, thanks for looking into this again.

The bf16 vs. bfloat16 point is a good catch, but it only applies to dtype in the main part of the config. Using it for the quantizer precision produces the same error, because the string is never converted into a dtype on its way from the config to the quantizer constructor:

  File "/home/mreso/torchtune/recipes/generate.py", line 72, in _setup_model
    model = self._quantizer.quantize(model)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 1014, in quantize
    state_dict = self._create_quantized_state_dict(model)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mreso/.conda/envs/tune/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 989, in _create_quantized_state_dict
    weight.to(self.precision),
RuntimeError: Invalid device string: 'bf16'

Note that for this error I am not setting dtype but precision for the quantizer:

quantizer:
  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
  precision: bf16

precision is a parameter of the quantizer constructor, but it does not accept a string (only a dtype).

Regarding the quantizer not being affected by .to(): the quantizer replaces the linear layers with Int8DynActInt4WeightLinear, which is derived from nn.Module, so its parameters will be affected by .to(). However, precision is just an unregistered attribute, which will not change: https://github.com/pytorch/ao/blob/b523f9f9e15b6fb80d10f585d9cf45e0c5e4d10e/torchao/quantization/GPTQ.py#L927
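
A small sketch of that behaviour, using a made-up ToyLinear module rather than the torchao class: .to() converts registered parameters, but a dtype stored as an ordinary attribute keeps its original value, so the layer still computes and returns float32. The class and its forward logic are assumptions for illustration only.

```python
import torch
from torch import nn


class ToyLinear(nn.Module):
    """Hypothetical stand-in for a layer that stores a dtype as a plain attribute."""

    def __init__(self, precision=torch.float32):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 4))
        # Plain Python attribute: nn.Module.to() does not know about it.
        self.precision = precision

    def forward(self, x):
        # Compute in self.precision regardless of the input or weight dtype.
        return nn.functional.linear(x.to(self.precision), self.weight.to(self.precision))


layer = ToyLinear().to(torch.bfloat16)
out = layer(torch.ones(1, 4, dtype=torch.bfloat16))

print(layer.weight.dtype)  # the parameter follows .to()
print(layer.precision)     # the attribute does not
print(out.dtype)           # so the output dtype follows the attribute
```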

Making it respond to .to() would mean either overriding .to() or registering a parameter with nn.Module and extracting that parameter's dtype when calling forward. Neither is a very nice option.

It is probably better either to change the constructor of Int8DynActInt4WeightQuantizer so that precision accepts strings as well as dtypes and converts the strings to dtypes internally, or to convert the string between reading it from the config and constructing the quantizer, as proposed above.

mreso commented 2 months ago

@felipemello1 Turns out the precision in the quantizer was a red herring. It actually needs to be float32 to capture the result of the matmul (see the inductor error mentioned above). I've now cast the k, v, and q tensors to the right dtype and made minor changes to make it compile. I will put the changes in a PR for further discussion.

mreso commented 2 months ago

@felipemello1 Here is the PR: #1371