dvmazur / mixtral-offloading

Run Mixtral-8x7B models in Colab or consumer desktops
MIT License

Run on second GPU (torch.device("cuda:1")) #24

Open · imabot2 opened this issue 9 months ago

imabot2 commented 9 months ago

Hi, you did awesome work! I ran your code on an RTX 3090 with offload_per_layer = 0 and it works great!

I noticed that when I switch to my second GPU with device = torch.device("cuda:1"), the model loads into GPU memory correctly, but inference fails.
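The device selection really is the only change from the working run. A minimal sketch of what I mean (illustrative, not my exact script; the model itself is built as in the repo's demo notebook):

```python
import torch

device = torch.device("cuda:1")  # the one change from the working cuda:0 run

# Tensors I create or move explicitly do land on the second GPU ...
x = torch.ones(4, device=device)
print(x.device)                      # cuda:1

# ... but note that this alone does not change PyTorch's *current* CUDA device:
print(torch.cuda.current_device())   # still 0
```

With that one change, model.generate fails with: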

Traceback (most recent call last):
  File "/home/philippe/tmp/mixtral2/main.py", line 112, in <module>
    result = model.generate(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/generation/utils.py", line 2861, in sample
    outputs = self(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1213, in forward
    outputs = self.model(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1081, in forward
    layer_outputs = decoder_layer(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 797, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 305, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/custom_layers.py", line 50, in forward
    return self.forward_triton(x)
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/custom_layers.py", line 80, in forward_triton
    output = fn(
  File "/home/philippe/tmp/mixtral2/mixtral-offloading/src/triton_kernels.py", line 172, in triton_matmul4_transpose
    matmul4_kernel_transpose[grid](
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in run
    ret = self.fn.run(
  File "/home/philippe/tmp/mixtral2/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

I can't figure out what's wrong. Any ideas?

Soumadip-Saha commented 9 months ago

Was the model working for you with offload_per_layer = 3? I was trying to run it on a V100 in Google Colab but hit an issue with Triton. Most likely this is a Triton version problem: if you are using v2.2.0, you have to downgrade (for example, pip install "triton==2.1.0"). You can refer to issue #25, which I raised. See if that works.
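If the Triton version is not the culprit, the error message itself may point to a device mismatch rather than a real CPU tensor: as far as I understand, Triton launches kernels on the current CUDA device, which stays at cuda:0 unless you change it, while your tensors live on cuda:1. Two things worth trying (untested guesses on my side, since I only have the single Colab GPU):

```python
import torch

# Guess 1: make cuda:1 the current device *before* building and running
# the model, so Triton kernels launch on the same GPU the tensors live on.
torch.cuda.set_device(1)
device = torch.device("cuda:1")

# Guess 2 (from the shell, no code changes): hide the first GPU so the
# second one is exposed to PyTorch as cuda:0:
#   CUDA_VISIBLE_DEVICES=1 python main.py
```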