GATECH-EIC / ShiftAddLLM

ShiftAddLLM: Accelerating Pretrained LLMs via Post-Training Multiplication-Less Reparameterization
Apache License 2.0

Error: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! #4

Open viraatdas opened 1 week ago

viraatdas commented 1 week ago

Running this command:

CUDA_VISIBLE_DEVICES=0 python3 model/llama.py ShiftAddLLM/Llama-2-70b-wbits2-acc

produces the following error:

CUDA extension not installed.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:02<00:00, 12.00it/s]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 8192)
    (layers): ModuleList(
      (0-79): 80 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=8192, out_features=8192, bias=False)
          (k_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (v_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (o_proj): Linear(in_features=8192, out_features=8192, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=8192, out_features=28672, bias=False)
          (up_proj): Linear(in_features=8192, out_features=28672, bias=False)
          (down_proj): Linear(in_features=28672, out_features=8192, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((8192,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=8192, out_features=32000, bias=False)
)
wikitext2
Evaluating ...
Traceback (most recent call last):
  File "/home/ubuntu/ShiftAddLLM/model/llama.py", line 310, in <module>
    llama_eval(model, testloader, DEV)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/ShiftAddLLM/model/llama.py", line 206, in llama_eval
    model(batch)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 977, in forward
    position_embeddings = self.rotary_emb(hidden_states, position_ids)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 209, in forward
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)

viraatdas commented 1 week ago

Please be wary of @Klevente12's response. The zip file he linked appears to raise a warning when scanned.

VirusTotal scan of the zip file: https://www.virustotal.com/gui/file/9be8cc69033e664bb32237516e3d6904c1de3c96fc9f0a30dc4a747d927886ff

This seems to be a spam account. Please flag/delete the above comment to be safe.

viraatdas commented 1 week ago

This PR should fix it: https://github.com/GATECH-EIC/ShiftAddLLM/pull/5
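
For anyone hitting the same traceback, the failing line multiplies `inv_freq_expanded` by `position_ids_expanded` inside the rotary embedding, so one plausible reading is that some tensor belonging to the model is still on the CPU while the inputs are on `cuda:0`. The helper below is a minimal diagnostic sketch for checking that, not the change made in the PR. The function names `tensors_by_device` and `report_device_mismatch` are hypothetical, and the code only assumes the object passed in exposes `named_parameters()` and `named_buffers()` the way a `torch.nn.Module` does.

```python
from collections import defaultdict


def tensors_by_device(module):
    """Group parameter and buffer names of `module` by device.

    `module` is assumed to behave like a torch.nn.Module, i.e. it exposes
    named_parameters() and named_buffers(), each yielding (name, tensor)
    pairs whose tensors have a `.device` attribute. Nothing here imports torch.
    """
    groups = defaultdict(list)
    for name, tensor in module.named_parameters():
        groups[str(tensor.device)].append(name)
    for name, tensor in module.named_buffers():
        groups[str(tensor.device)].append(name)
    return dict(groups)


def report_device_mismatch(module, expected):
    """Return the sorted names of tensors that are not on `expected`.

    `expected` is a device string such as "cuda:0" (hypothetical helper).
    """
    return sorted(
        name
        for device, names in tensors_by_device(module).items()
        if device != expected
        for name in names
    )
```

Calling `report_device_mismatch(model.model.rotary_emb, "cuda:0")` just before the evaluation call would list any tensors of the rotary embedding module that were left on the CPU. If the frequencies are stored in a buffer (an assumption based on the `inv_freq_expanded` name in the traceback), that buffer would appear in the result, and moving the module to the same device as the inputs would be one way to resolve the mismatch.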