Vahe1994 / AQLM

Official PyTorch repository for "Extreme Compression of Large Language Models via Additive Quantization" (https://arxiv.org/pdf/2401.06118.pdf) and "PV-Tuning: Beyond Straight-Through Estimation for Extreme LLM Compression" (https://arxiv.org/abs/2405.14852)
Apache License 2.0

`RuntimeError: CUDA error: invalid argument` while running #25

Closed: alex4321 closed this issue 7 months ago

alex4321 commented 7 months ago

I have an Ubuntu 23.10 system.

I installed CUDA Toolkit 12.1 using https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local

(since building the kernels needs the headers and so on, I can't just install CUDA through conda).

The rest of my environment:

accelerate @ git+https://github.com/huggingface/accelerate.git@97d2168e5953fe7373a06c69c02c5a00a84d5344
anyio==4.2.0
aqlm @ file:///home/alex4321/Documents/AQLM/inference_lib
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
Brotli @ file:///work/perseverance-python-buildout/croot/brotli-split_1698805593785/work
certifi @ file:///croot/certifi_1696279375225/work/certifi
cffi @ file:///croot/cffi_1700254295673/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
comm==0.2.1
cryptography @ file:///work/perseverance-python-buildout/croot/cryptography_1698845900024/work
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
executing==2.0.1
fastjsonschema==2.19.1
filelock @ file:///work/perseverance-python-buildout/croot/filelock_1698846025262/work
fqdn==1.5.1
fsspec==2024.2.0
h11==0.14.0
httpcore==1.0.3
httpx==0.26.0
huggingface-hub==0.20.3
idna @ file:///work/perseverance-python-buildout/croot/idna_1698845632828/work
ipykernel==6.29.2
ipython==8.21.0
isoduration==20.11.0
jedi==0.19.1
Jinja2 @ file:///work/perseverance-python-buildout/croot/jinja2_1698847462642/work
json5==0.9.14
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-lsp==2.2.2
jupyter_client==8.6.0
jupyter_core==5.7.1
jupyter_server==2.12.5
jupyter_server_terminals==0.5.2
jupyterlab==4.1.1
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.3
llvmlite==0.42.0
MarkupSafe @ file:///work/perseverance-python-buildout/croot/markupsafe_1698846636000/work
matplotlib-inline==0.1.6
mistune==3.0.2
mkl-service==2.4.0
mpmath @ file:///work/perseverance-python-buildout/croot/mpmath_1698864994882/work
nbclient==0.9.0
nbconvert==7.16.0
nbformat==5.9.2
nest-asyncio==1.6.0
networkx @ file:///work/perseverance-python-buildout/croot/networkx_1698865062738/work
ninja==1.11.1.1
notebook_shim==0.2.4
numba==0.59.0
numpy @ file:///work/perseverance-python-buildout/croot/numpy_and_numpy_base_1698845160062/work/dist/numpy-1.26.0-cp312-cp312-linux_x86_64.whl#sha256=fdc35057024038070345ff9f7f47ed48ecdb21dd72461617bdadf4f5d1634fcb
overrides==7.7.0
packaging==23.2
pandocfilters==1.5.1
parso==0.8.3
pexpect==4.9.0
Pillow @ file:///work/perseverance-python-buildout/croot/pillow_1698847657722/work
platformdirs==4.2.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
Pygments==2.17.2
pyOpenSSL @ file:///work/perseverance-python-buildout/croot/pyopenssl_1698863523157/work
PySocks @ file:///work/perseverance-python-buildout/croot/pysocks_1698845478203/work
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML @ file:///work/perseverance-python-buildout/croot/pyyaml_1698849903511/work
pyzmq==25.1.2
referencing==0.33.0
regex==2023.12.25
requests @ file:///work/perseverance-python-buildout/croot/requests_1698846321763/work
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.18.0
safetensors==0.4.2
scipy==1.12.0
Send2Trash==1.8.2
setuptools==68.0.0
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy @ file:///croot/sympy_1701397643339/work
terminado==0.18.0
tinycss2==1.2.1
tokenizers==0.15.2
torch==2.2.0
torchaudio==2.2.0
torchvision==0.17.0
tornado==6.4
tqdm==4.66.2
traitlets==5.14.1
transformers==4.37.0
triton==2.2.0
types-python-dateutil==2.8.19.20240106
typing_extensions==4.9.0
uri-template==1.3.0
urllib3 @ file:///work/perseverance-python-buildout/croot/urllib3_1698845837793/work
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
wheel==0.41.2

AQLM is installed from the latest GitHub state.

Now if I try to run some code:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

llama_cpu_quantized_model = AutoModelForCausalLM.from_pretrained(
    "BlackSamorez/Llama-2-7b-AQLM-2Bit-2x8-hf",
    trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
)
llama_gpu_quantized_model = AutoModelForCausalLM.from_pretrained(
    "BlackSamorez/Llama-2-7b-AQLM-2Bit-2x8-hf",
    trust_remote_code=True, torch_dtype=torch.float16, device_map="cuda:0"
).cuda()
llama_tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-hf")

output = llama_cpu_quantized_model.generate(llama_tokenizer("Test is", return_tensors="pt")["input_ids"], max_new_tokens=10)
print(llama_tokenizer.decode(output[0]))

it tells me

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Compiling AQLM numba kernel with parameters: kernel_key=(8, 4096, 4096, 2)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
    - Avoid using `tokenizers` before the fork if possible
    - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Compiling AQLM numba kernel with parameters: kernel_key=(8, 11008, 4096, 2)
Compiling AQLM numba kernel with parameters: kernel_key=(8, 4096, 11008, 2)
<s> Test is a 19999 film directed by

which I guess is more or less fine.

But:

output = llama_gpu_quantized_model.generate(llama_tokenizer("Test is", return_tensors="pt")["input_ids"].cuda(), max_new_tokens=10)
print(llama_tokenizer.decode(output[0]))

gives me

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 1
----> 1 output = llama_gpu_quantized_model.generate(llama_tokenizer("Test is", return_tensors="pt")["input_ids"].cuda(), max_new_tokens=10)
      2 print(llama_tokenizer.decode(output[0]))

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/transformers/generation/utils.py:1474, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1457     return self.assisted_decoding(
   1458         input_ids,
   1459         candidate_generator=candidate_generator,
   (...)
   1470         **model_kwargs,
   1471     )
   1472 if generation_mode == GenerationMode.GREEDY_SEARCH:
   1473     # 11. run greedy search
-> 1474     return self.greedy_search(
   1475         input_ids,
   1476         logits_processor=prepared_logits_processor,
   1477         stopping_criteria=prepared_stopping_criteria,
   1478         pad_token_id=generation_config.pad_token_id,
   1479         eos_token_id=generation_config.eos_token_id,
   1480         output_scores=generation_config.output_scores,
   1481         return_dict_in_generate=generation_config.return_dict_in_generate,
   1482         synced_gpus=synced_gpus,
   1483         streamer=streamer,
   1484         **model_kwargs,
   1485     )
   1487 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
   1488     if not model_kwargs["use_cache"]:

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/transformers/generation/utils.py:2335, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2332 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2334 # forward pass to get next token
-> 2335 outputs = self(
   2336     **model_inputs,
   2337     return_dict=True,
   2338     output_attentions=output_attentions,
   2339     output_hidden_states=output_hidden_states,
   2340 )
   2342 if synced_gpus and this_peer_finished:
   2343     continue  # don't waste resources running the code we don't need

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.cache/huggingface/modules/transformers_modules/BlackSamorez/Llama-2-7b-AQLM-2Bit-2x8-hf/2df1b7a5cbb2a8b584eade2de5c2b4975072a644/modeling_llama_aqlm.py:1195, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1192 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1194 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1195 outputs = self.model(
   1196     input_ids=input_ids,
   1197     attention_mask=attention_mask,
   1198     position_ids=position_ids,
   1199     past_key_values=past_key_values,
   1200     inputs_embeds=inputs_embeds,
   1201     use_cache=use_cache,
   1202     output_attentions=output_attentions,
   1203     output_hidden_states=output_hidden_states,
   1204     return_dict=return_dict,
   1205 )
   1207 hidden_states = outputs[0]
   1208 if self.config.pretraining_tp > 1:

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)

File ~/.cache/huggingface/modules/transformers_modules/BlackSamorez/Llama-2-7b-AQLM-2Bit-2x8-hf/2df1b7a5cbb2a8b584eade2de5c2b4975072a644/modeling_llama_aqlm.py:1082, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1072     layer_outputs = self._gradient_checkpointing_func(
   1073         decoder_layer.__call__,
   1074         hidden_states,
   (...)
   1079         use_cache,
   1080     )
   1081 else:
-> 1082     layer_outputs = decoder_layer(
   1083         hidden_states,
   1084         attention_mask=attention_mask,
   1085         position_ids=position_ids,
   1086         past_key_value=past_key_values,
   1087         output_attentions=output_attentions,
   1088         use_cache=use_cache,
   1089     )
   1091 hidden_states = layer_outputs[0]
   1093 if use_cache:

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)

File ~/.cache/huggingface/modules/transformers_modules/BlackSamorez/Llama-2-7b-AQLM-2Bit-2x8-hf/2df1b7a5cbb2a8b584eade2de5c2b4975072a644/modeling_llama_aqlm.py:810, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    807 hidden_states = self.input_layernorm(hidden_states)
    809 # Self Attention
--> 810 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    811     hidden_states=hidden_states,
    812     attention_mask=attention_mask,
    813     position_ids=position_ids,
    814     past_key_value=past_key_value,
    815     output_attentions=output_attentions,
    816     use_cache=use_cache,
    817     **kwargs,
    818 )
    819 hidden_states = residual + hidden_states
    821 # Fully Connected

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)

File ~/.cache/huggingface/modules/transformers_modules/BlackSamorez/Llama-2-7b-AQLM-2Bit-2x8-hf/2df1b7a5cbb2a8b584eade2de5c2b4975072a644/modeling_llama_aqlm.py:705, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    694     return super().forward(
    695         hidden_states=hidden_states,
    696         attention_mask=attention_mask,
   (...)
    700         use_cache=use_cache,
    701     )
    703 bsz, q_len, _ = hidden_states.size()
--> 705 query_states = self.q_proj(hidden_states)
    706 key_states = self.k_proj(hidden_states)
    707 value_states = self.v_proj(hidden_states)

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
File ~/anaconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/aqlm/inference.py:65, in QuantizedLinear.forward(self, input)
     59 if (
     60     not input.is_cuda
     61     and self.codebook_size == 256
     62     and self.codes.shape[0] == self.out_features // self.out_group_size
     63 ):
     64     self.codes.data = torch.permute(self.codes.data, (1, 0, 2)).contiguous()  #  TODO: fix this thing
---> 65 return forward_pass_quantized_linear(input, self.codes, self.codebooks, self.scales, self.bias)

File ~/anaconda3/envs/llms/lib/python3.12/site-packages/aqlm/inference_kernels/kernel_selector.py:31, in forward_pass_quantized_linear(input, codes, codebooks, scales, bias)
     26     from .cuda_kernel import CUDA_KERNEL
     28     assert (
     29         input.dtype == torch.float16
     30     ), f"please load the model with `torch_dtype=torch.float16`, as {input.dtype} is not supported on GPU yet"
---> 31     return CUDA_KERNEL.code2x8_matmat(input, codes, codebooks, scales) + (bias if bias is not None else 0)
     32 case (True, _, _, _, _):
     33     from .triton_kernel import triton_matmul

RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

P.S. By the way, it's kinda off-topic, but one thing I don't get: why is the axes order different for the Numba kernels?

BlackSamorez commented 7 months ago

Hi! I'll take a more detailed look tomorrow.

In the meantime, what GPU are you using? We've observed that the 2x8 kernel might fail on older GPUs that don't have enough cache. We haven't properly determined which GPUs are affected, though.
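A minimal standalone check along these lines (a sketch using only standard CUDA runtime API calls, nothing AQLM-specific) can report how much dynamic shared memory a single block may opt in to on device 0:

// check_shared_mem.cu -- hypothetical helper; compile with `nvcc check_shared_mem.cu`
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int major = 0, minor = 0, default_limit = 0, optin_limit = 0;
  // Compute capability of device 0.
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
  // Default per-block dynamic shared memory limit, and the larger opt-in
  // limit that cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...)
  // may raise a kernel to.
  cudaDeviceGetAttribute(&default_limit, cudaDevAttrMaxSharedMemoryPerBlock, 0);
  cudaDeviceGetAttribute(&optin_limit, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
  printf("compute capability:            %d.%d\n", major, minor);
  printf("shared memory/block (default): %d bytes\n", default_limit);
  printf("shared memory/block (opt-in):  %d bytes\n", optin_limit);
  return 0;
}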

alex4321 commented 7 months ago

A 2080 Ti.

And I commented out code to figure out the first place the error occurs, which is:

  cudaFuncSetAttribute(
    Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
  );

@BlackSamorez

alex4321 commented 7 months ago

So

void  code2x8_matvec_cuda(
  const void* __restrict__ A,
  const void* __restrict__ B,
        void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k
) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  /*
  cudaFuncSetAttribute(
    Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
  );
  Code2x8MatVec<<<blocks, threads, shared>>>(
    (const int4*) A,
    (const int4*) B,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k
  );
  */
}

runs (producing garbage output, since the kernel never launches and the output buffer is never written), while

void  code2x8_matvec_cuda(
  const void* __restrict__ A,
  const void* __restrict__ B,
        void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k
) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaFuncSetAttribute(
    Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
  );
  /*
  Code2x8MatVec<<<blocks, threads, shared>>>(
    (const int4*) A,
    (const int4*) B,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k
  );
  */
}

does not. So the failing call is `cudaFuncSetAttribute` itself.
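Since `cudaFuncSetAttribute` returns an error code, a small change along these lines (a sketch against the snippet above, reusing its `Code2x8MatVec` and `shared`; assumes `<cstdio>` is available) would surface the failure right at the call instead of letting it show up later as an asynchronous CUDA error:

// Sketch: check the return code of the call the bisection above points at.
cudaError_t err = cudaFuncSetAttribute(
  Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
);
if (err != cudaSuccess) {
  // On a device whose opt-in shared memory limit is below `shared`,
  // this prints "invalid argument" immediately, at the failing call.
  fprintf(stderr, "cudaFuncSetAttribute(%d bytes) failed: %s\n",
          shared, cudaGetErrorString(err));
}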

alex4321 commented 7 months ago

Well, it seems the problem is probably on my side.

According to the CUDA C Programming Guide, compute capability 7.x devices allow a single thread block to dynamically allocate up to 64 KB of shared memory on Turing.

And the 2080 Ti is compute capability 7.5.

So my maximum shared memory per block is 64 KB, i.e. 65536 bytes, while the kernel:

  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaFuncSetAttribute(
    Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
  );

tries to set a shared memory size of 16 * (2 * 256 * 8 + 32 * 9) = 70144 bytes.
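Spelled out as a fragment (the only AQLM-specific number is the `shared` expression copied from the kernel above; the query is plain CUDA runtime API):

// Sketch: compare the kernel's request against the device's opt-in limit.
int shared = 16 * (2 * 256 * 8 + 32 * 9);  // 16 * (4096 + 288) = 70144 bytes
int optin_limit = 0;
cudaDeviceGetAttribute(&optin_limit, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
// On a 2080 Ti (compute capability 7.5) optin_limit is 65536 (64 KB),
// so shared > optin_limit and cudaFuncSetAttribute rejects the request
// with "invalid argument".
printf("requested %d bytes, opt-in limit %d bytes\n", shared, optin_limit);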

BlackSamorez commented 7 months ago

Looks like there isn't much we can do about it, then. The speedup is mostly there because the codebooks fit into shared memory. Concerning your second question about the axes order: it is indeed different for the Numba kernels, because they benefit from a different memory layout than all the other kernels. You can see the tensors being transposed once during inference. The code is a mess, and we're hoping to improve both the speed and the readability by implementing a proper one-time kernel selector in the near future.
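Not the repo's actual plan, but as a sketch of what such a one-time selector could check before picking the 2x8 CUDA kernel on devices like the 2080 Ti (the function name and the fallback routing are hypothetical):

// Hypothetical guard: probe once whether the 2x8 kernel's shared-memory
// demand fits on this device; callers could route to another kernel
// (e.g. the Triton path) when it returns false.
bool code2x8_shared_mem_fits(int device) {
  int optin_limit = 0;
  cudaDeviceGetAttribute(&optin_limit, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  const int required = 16 * (2 * 256 * 8 + 32 * 9);  // 70144 bytes
  return required <= optin_limit;  // false on CC 7.5 (65536-byte opt-in limit)
}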

alex4321 commented 7 months ago

Thanks for pointing out the source of the issue, anyway.