pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

LinearActivationQuantizedTensor dispatch error when model quantized by QAT generate #1775

Closed elfisworking closed 2 weeks ago

elfisworking commented 2 weeks ago

I quantized a Llama3-8B model using QAT. When I tried to run inference with the quantized model, I encountered the following error. Logs:

tune run generate --config llama3_generation_config.yaml
2024-10-09:08:06:20,408 INFO     [_logging.py:101] Running InferenceRecipe with resolved config:

chat_format: null
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: /QAT/output/llama3-8B/
  checkpoint_files:
  - meta_model_2-8da4w.pt
  model_type: LLAMA3
  output_dir: /QAT/output/llama3-8B/
device: cuda
dtype: bf16
enable_kv_cache: true
instruct_template: null
max_new_tokens: 300
model:
  _component_: torchtune.models.llama3.llama3_8b
prompt: Tell me a joke?
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 42
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /QAT/Meta-Llama-3-8B/original/tokenizer.model
top_k: 300

2024-10-09:08:06:20,812 DEBUG    [seed.py:60] Setting manual seed to local seed 42. Local seed is seed + rank = 42 + 0
Traceback (most recent call last):
  File "/usr/local/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/usr/local/lib/python3.10/dist-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/torchtune/_cli/run.py", line 196, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/usr/local/lib/python3.10/dist-packages/torchtune/_cli/run.py", line 102, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/usr/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/usr/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/recipes/generate.py", line 211, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
  File "/usr/local/lib/python3.10/dist-packages/recipes/generate.py", line 206, in main
    recipe.setup(cfg=cfg)
  File "/usr/local/lib/python3.10/dist-packages/recipes/generate.py", line 55, in setup
    self._model = self._setup_model(
  File "/usr/local/lib/python3.10/dist-packages/recipes/generate.py", line 73, in _setup_model
    model.load_state_dict(model_state_dict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
        While copying the parameter named "layers.0.attn.q_proj.weight", whose dimensions in the model are torch.Size([4096, 4096]) and whose dimensions in the checkpoint are torch.Size([4096, 4096]), an exception occurred : ('LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.copy_.default',).
        While copying the parameter named "layers.0.attn.k_proj.weight", whose dimensions in the model are torch.Size([1024, 4096]) and whose dimensions in the checkpoint are torch.Size([1024, 4096]), an exception occurred : ('LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.copy_.default',).
        While copying the parameter named "layers.0.attn.v_proj.weight", whose dimensions in the model are torch.Size([1024, 4096]) and whose dimensions in the checkpoint are torch.Size([1024, 4096]), an exception occurred : ('LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.copy_.default',).
        While copying the parameter named "layers.0.attn.output_proj.weight", whose dimensions in the model are torch.Size([4096, 4096]) and whose dimensions in the checkpoint are torch.Size([4096, 4096]), an exception occurred : ('LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.copy_.default',).
        While copying the parameter named "layers.0.mlp.w1.weight", whose dimensions in the model are torch.Size([14336, 4096]) and whose dimensions in the checkpoint are torch.Size([14336, 4096]), an exception occurred : ('LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.copy_.default',).
        While copying the parameter named "layers.0.mlp.w2.weight", whose dimensions in the model are torch.Size([4096, 14336]) and whose 
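Since the error repeats once per weight, a short script can condense it before sharing. The following is only a sketch: the function name `failed_params` and the `sample` text are made up for illustration, and the regular expression assumes the message wording shown in the log above.

```python
import re
from collections import Counter

# One entry of the load_state_dict RuntimeError, using the wording from the log above.
ENTRY = re.compile(
    r'While copying the parameter named "(?P<name>[^"]+)".*?'
    r"unimplemented operator/function:\s*(?P<op>[\w.]+)"
)


def failed_params(log_text):
    """Return (parameter name, unimplemented op) pairs found in the pasted error."""
    # Terminal wrapping can split words across lines, so join the lines first.
    flat = log_text.replace("\n", "")
    return [(m.group("name"), m.group("op")) for m in ENTRY.finditer(flat)]


# Hypothetical two-entry excerpt, wrapped the way a terminal paste might be.
sample = (
    'While copying the parameter named "layers.0.attn.q_proj.weight", whose dimensions '
    "in the model are torch.Size([4096, 4096]), an exception occurred : "
    "('LinearActivationQuantizedTensor dispatch: attempting to run unimplemented "
    "operator/fun\nction: aten.copy_.default',).\n"
    '        While copying the parameter named "layers.0.mlp.w1.weight", whose dimensions '
    "in the model are torch.Size([14336, 4096]), an exception occurred : "
    "('LinearActivationQuantizedTensor dispatch: attempting to run unimplemented "
    "operator/function: a\nten.copy_.default',).\n"
)

pairs = failed_params(sample)
per_op = Counter(op for _, op in pairs)
print(per_op)
```

An entry that is cut off before the operator name, like the last line of the log above, is simply not reported.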

generation.yaml is

#config for running the InferenceRecipe in generate.py to generate output from an LLM
#
# To launch, run the following command from root torchtune directory:
#    tune run generate --config generation

# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: /QAT/output/llama3-8B/
  checkpoint_files: [
    meta_model_2-8da4w.pt
  ]
  output_dir: /QAT/output/llama3-8B/
  model_type: LLAMA3

device: cuda
dtype: bf16
seed: 42

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /QAT/Meta-Llama-3-8B/original/tokenizer.model
  max_seq_len: null

# Generation arguments; defaults taken from gpt-fast
prompt: "Tell me a joke?"
instruct_template: null
chat_format: null
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

enable_kv_cache: True

quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256

Can anyone help me? Thanks very much.

SalmanMohammadi commented 2 weeks ago

I'm also getting this error. Could you help me out here and share some details about your environment please?

1) Which hardware are you running on (MPS, CUDA)? If CUDA, which GPU? 2) Could you share your PyTorch and torchtune versions?

elfisworking commented 2 weeks ago

device: NVIDIA A100
torchtune: 0.3.1
torchao: 0.5.0
torch: 2.4.0
torchvision: 0.19.0
model: llama3-8B

I got the quantized model by following this tutorial: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html
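For reference, package versions like these can be collected with the standard library instead of being typed by hand. This is a minimal sketch; the helper name `collect_versions` and the package list are chosen here for illustration and are not part of torchtune.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages relevant to this report; adjust the tuple as needed.
PACKAGES = ("torch", "torchao", "torchtune", "torchvision")


def collect_versions(packages=PACKAGES):
    """Map each package name to its installed version, or 'not installed'."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name:<12} {ver}")
```

Run it inside the same Python environment that `tune` uses, otherwise the reported versions may not match the ones that actually produced the error.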

elfisworking commented 2 weeks ago

quantization.yaml

model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /QAT/Meta-Llama-3-8B/
  checkpoint_files: [
    meta_model_2.pt
  ]
  recipe_checkpoint: null
  output_dir: /QAT/output/llama3-8B/
  model_type: LLAMA3

device: cuda
dtype: bf16
seed: 42

quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

train.yaml

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /QAT/Meta-Llama-3-8B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  source: parquet
  data_files: /QAT/dataset/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
seed: 42
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /QAT/Meta-Llama-3-8B/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /QAT/output/llama3-8B
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 16
epochs: 3

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /QAT/Meta-Llama-3-8B/finetune-logs
log_every_n_steps: 1
log_peak_memory_stats: False

SalmanMohammadi commented 2 weeks ago

@elfisworking Similar to my other comment, let me know if this works for you when you get a chance.

elfisworking commented 1 week ago

OK, I will try it.

elfisworking commented 1 week ago

@SalmanMohammadi Hello, I tried the code from the new branch on my model after QAT quantization, and I got this log:

2024-10-10:06:34:56,531 INFO     [_logging.py:101] Running InferenceRecipe with resolved config:

chat_format: null
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: /QAT/output/llama3-8B/
  checkpoint_files:
  - meta_model_2-8da4w.pt
  model_type: LLAMA3
  output_dir: /QAT/output/llama3-8B/
device: cuda
dtype: bf16
enable_kv_cache: true
instruct_template: null
max_new_tokens: 300
model:
  _component_: torchtune.models.llama3.llama3_8b
prompt: Tell me a joke?
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 42
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /QAT/Meta-Llama-3-8B/original/tokenizer.model
top_k: 300

2024-10-10:06:34:57,026 DEBUG    [seed.py:60] Setting manual seed to local seed 42. Local seed is seed + rank = 42 + 0
2024-10-10:06:35:07,106 INFO     [generate.py:97] Model is initialized with precision torch.bfloat16.
2024-10-10:06:35:08,726 INFO     [generate.py:168] Starting compilation to improve generation performance ...
Traceback (most recent call last):
  File "/usr/local/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/QAT/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/QAT/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/QAT/torchtune/torchtune/_cli/run.py", line 208, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/QAT/torchtune/torchtune/_cli/run.py", line 102, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/usr/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/usr/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/QAT/torchtune/recipes/generate.py", line 229, in <module>
    sys.exit(main())
  File "/QAT/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
  File "/QAT/torchtune/recipes/generate.py", line 225, in main
    recipe.generate(cfg=cfg)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/QAT/torchtune/recipes/generate.py", line 173, in generate
    _ = generation.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/QAT/torchtune/torchtune/generation/_generation.py", line 366, in generate
    tokens, logits = custom_generate_next_token(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1512, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/lazy.py", line 132, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1512, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1512, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1459, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/nn_module.py", line 409, in call_function
    return wrap_fx_proxy(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1713, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1798, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1908, in run_node
    return nnmodule(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 372, in _dispatch__torch_function__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 355, in wrapper
    return func(f, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 102, in _
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
  File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 73, in _quantized_linear_op
    return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
  File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 372, in _dispatch__torch_function__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 355, in wrapper
    return func(f, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 1500, in _
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__model___layers_0_attn_q_proj(*(FakeTensor(..., device='cuda:0', size=(1, 1, 4096), dtype=torch.bfloat16),), **{}):
'FakeTensor' object has no attribute '_quantized_linear_op'

from user code:
   File "/QAT/torchtune/torchtune/generation/_generation.py", line 102, in generate_next_token
    logits = model(x, input_pos=input_pos, mask=mask)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/QAT/torchtune/torchtune/modules/transformer.py", line 599, in forward
    h = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/QAT/torchtune/torchtune/modules/transformer.py", line 114, in forward
    attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/QAT/torchtune/torchtune/modules/attention.py", line 229, in forward
    q = self.q_proj(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

It says: 'FakeTensor' object has no attribute '_quantized_linear_op'
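The log suggests falling back to eager when compilation fails. One way to do that per function, rather than globally, is a small wrapper. This is only a sketch: `compile_with_eager_fallback` is a hypothetical helper, not a torchtune or PyTorch API, and the compiler is passed in as an argument so the wrapper itself does not depend on torch.

```python
from typing import Any, Callable


def compile_with_eager_fallback(
    fn: Callable[..., Any],
    compile_fn: Callable[[Callable[..., Any]], Callable[..., Any]],
) -> Callable[..., Any]:
    """Call the compiled version of fn; switch to plain fn if a compiled call raises."""
    compiled = compile_fn(fn)
    state = {"use_compiled": True}

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if state["use_compiled"]:
            try:
                return compiled(*args, **kwargs)
            except Exception as exc:
                # With torch.compile, tracing happens on the first call,
                # so compile errors surface here rather than at wrap time.
                print(f"compiled call failed ({type(exc).__name__}); using eager")
                state["use_compiled"] = False
        return fn(*args, **kwargs)

    return wrapper


# Possible use with the recipe's names (assumes torch and torchtune are imported):
#   next_token = compile_with_eager_fallback(
#       generation.generate_next_token,
#       lambda f: torch.compile(f, mode="max-autotune", fullgraph=True),
#   )
```

Note that the eager call repeats the whole function, so this is only safe when a failed compiled call has not already changed any state such as a KV cache.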

elfisworking commented 1 week ago

software version:

torch                             2.4.0
torchao                           0.7.0.dev20241010+cu121
torchtune                         0.0.0                   /QAT/torchtune # from main branch
torchvision                       0.19.0

elfisworking commented 1 week ago

Should I open a new issue? @SalmanMohammadi

SalmanMohammadi commented 1 week ago

> Should I open a new issue? @SalmanMohammadi

That would be really helpful, thank you!

elfisworking commented 1 week ago

After commenting out the code below (recipes/generate.py, lines 167-184), generation works:

        if self._quantization_mode is not None:
            logger.info("Starting compilation to improve generation performance ...")
            custom_generate_next_token = torch.compile(
                generation.generate_next_token, mode="max-autotune", fullgraph=True
            )
            t0 = time.perf_counter()
            _ = generation.generate(
                model=self._model,
                prompt=prompt,
                max_generated_tokens=2,
                temperature=cfg.temperature,
                top_k=cfg.top_k,
                stop_tokens=self._tokenizer.stop_tokens,
                custom_generate_next_token=custom_generate_next_token,
            )
            t = time.perf_counter() - t0
            logger.info(f"Warmup run for quantized model takes: {t:.02f} sec")
            self._model.reset_caches()

So this error may be caused by torch.compile. With that block commented out, the generation run looks like this:

root@nx-zhongwei-4:/QAT/test# tune run generate --config llama3_generation_config.yaml
INFO:torchtune.utils._logging:Running InferenceRecipe with resolved config:

chat_format: null
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: /QAT/output/llama3-8B/
  checkpoint_files:
  - meta_model_2-8da4w.pt
  model_type: LLAMA3
  output_dir: /QAT/output/llama3-8B/
device: cuda
dtype: bf16
enable_kv_cache: false
instruct_template: null
max_new_tokens: 300
model:
  _component_: torchtune.models.llama3.llama3_8b
prompt: Tell me a joke?
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 42
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /QAT/Meta-Llama-3-8B/original/tokenizer.model
top_k: 300

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 42. Local seed is seed + rank = 42 + 0
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Tell me a joke? A funny story?
Q: What did the fish say when it swam into a wall? A: Dam!
Q: What did the volcano say to the other volcano? A: Let's lava flow!
Q: Why did the chicken cross the playground? A: To get to the other slides!
Q: What did the tree say to the acorn? A: Come on, let's go nuts!
Q: What did the giraffe say when it got its first pair of glasses? A: Wow, I can finally see the leaves on the trees!
Q: What did the spider do on the internet? A: Found websites!
Q: What do you call a snowman party? A: An igloo!
Q: What did the mummy scientist say? A: I'm a big fan of decomposition!
Q: What did the cloud say to the lightning? A: I'm feeling struck!
Q: What did the volcano say when it erupted? A: I'm gonna let my lava flow!
Q: What did the giraffe say when it got its first pair of glasses? A: Wow, I can finally see the leaves on the trees!
Q: What did the spider do on the internet? A: Found websites!
Q: What do you call a snowman party? A: An igloo!
Q: What did the mummy scientist say? A: I'm a big fan of decomposition!
Q: What did the cloud say to
INFO:torchtune.utils._logging:Time for inference: 65.43 sec total, 4.59 tokens/sec
INFO:torchtune.utils._logging:Bandwidth achieved: 73.95 GB/s
INFO:torchtune.utils._logging:Memory used: 17.41 GB
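As a quick sanity check on these numbers, the snippet below recomputes the token count from the reported time and throughput, and the amount of data read per token from the reported bandwidth. It assumes that the bandwidth figure is data read per generated token multiplied by tokens per second; that is an assumption about how the recipe reports it, not something the log states.

```python
# Values copied from the log lines above.
total_seconds = 65.43
tokens_per_second = 4.59
bandwidth_gb_per_s = 73.95

# Tokens generated = time * throughput; should be close to max_new_tokens (300).
generated_tokens = total_seconds * tokens_per_second

# Under the stated assumption, GB read per token = bandwidth / throughput.
gb_per_token = bandwidth_gb_per_s / tokens_per_second

print(f"tokens generated: {generated_tokens:.0f}")  # prints 300
print(f"GB read per token: {gb_per_token:.1f}")     # prints 16.1
```

The first figure matches max_new_tokens: 300 in the config, so the run did not stop early on a stop token.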
SalmanMohammadi commented 1 week ago

Yeah it looks like a compile issue which is strange since I tested it with compile when I landed the fix.

SalmanMohammadi commented 1 week ago

Could you try upgrading your torch to a nightly version?

elfisworking commented 1 week ago

I have created a new issue and will try the nightly torch. Thanks.