galatolofederico / vanilla-llama

Plain pytorch implementation of LLaMA
GNU General Public License v3.0
189 stars, 31 forks

Enable the use of PyTorch 'mps' device for inference on Apple Silicon #5

Open chris-hatton opened 1 year ago

chris-hatton commented 1 year ago

The LLaMA.cpp project enables LLaMA inference on Apple Silicon devices using the CPU, but faster inference should be possible by supporting the M1/Pro/Max GPU in vanilla-llama, given that PyTorch now supports Apple Silicon GPUs through the 'mps' device.

I'm new to Python, but here are my observations:

In both generation.py and model.py there are calls to .cuda(), which can be replaced with:

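import torch

# Prefer CUDA, then Apple's Metal (MPS) backend, then fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
...
tensor_or_module = tensor_or_module.to(device)  # replaces tensor_or_module.cuda()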
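A minimal sketch of the same preference order, factored into a hypothetical pick_device helper that takes the two availability checks as plain booleans, so the logic can be checked on a machine without a GPU:

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: return a torch device string.

    Preference order: CUDA first, then Apple's MPS backend, then the CPU.
    The availability checks are passed in so the order can be exercised anywhere.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Usage with PyTorch (not executed here):
#   device = torch.device(pick_device(torch.cuda.is_available(),
#                                     torch.backends.mps.is_available()))
#   model = model.to(device)  # instead of model.cuda()
```

Keeping the checks as arguments also makes it easy to force the CPU path while debugging, e.g. pick_device(False, False).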

When I attempt to run example.py after this change, the Accelerate library throws an error: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'. Something is trying to use the CPU instead of MPS.

I wonder if this is because the call into Accelerate is load_checkpoint_and_dispatch with "auto" provided as the device map. Is PyTorch preferring the CPU over MPS here for some reason? Edit: This
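The error says a half-precision (float16) matrix multiply ran on the CPU, and in the PyTorch build used here the CPU addmm kernel has no Half implementation. If some weights do end up on the CPU, one possible workaround (an assumption, not something tested in this thread) is to keep them in float32 there. A minimal sketch with a hypothetical pick_dtype_name helper:

```python
def pick_dtype_name(device_type: str) -> str:
    """Hypothetical helper: pick a dtype name for the model weights.

    Half precision is kept on GPU backends ("cuda", "mps"); anything else
    (i.e. the CPU) gets float32, because the CPU addmm kernel reported
    "not implemented for 'Half'" in the PyTorch build used here.
    """
    if device_type in ("cuda", "mps"):
        return "float16"
    return "float32"


# Usage with PyTorch (not executed here):
#   dtype = getattr(torch, pick_dtype_name(device.type))  # torch.float16 or torch.float32
#   model = model.to(device=device, dtype=dtype)
```

float32 doubles the memory needed for the weights, so this only makes sense for the parts that actually land on the CPU.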

anentropic commented 1 year ago

Related:

When I try to run the convert.py script, I get:

  File "vanilla-llama/llama/model.py", line 99, in __init__
    ).cuda()
  File "vanilla-llama/lib/python3.10/site-packages/torch/cuda/__init__.py", line 239, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
anentropic commented 1 year ago

I tried getting past the OP's issue (RuntimeError: "addmm_impl_cpu_" not implemented for 'Half') by passing device_map=None.

Along the way I found and fixed this issue: https://github.com/huggingface/accelerate/pull/1297

The Accelerate docs say the MPS backend will be used by default when available (with the caveat that their device-map code doesn't support it yet), so I was hoping that would happen with device_map=None.

But then I get this puzzling error:

vanilla-llama/llama/generation.py", line 121, in sample_top_p
    next_token = torch.multinomial(probs_sort, num_samples=1)
NotImplementedError: Could not run 'aten::multinomial' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::multinomial' is only available for these backends: [CPU, MPS, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
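For reference, the failing sample_top_p step performs nucleus (top-p) sampling, which needs real probability values to draw from; meta tensors carry no data. A pure-Python sketch of the technique, assuming the masking rule used by the common LLaMA reference code (a token is dropped once the mass of the higher-ranked tokens already exceeds p); this is an illustration, not the repository's torch code:

```python
import random


def sample_top_p(probs, p, rng=random):
    """Nucleus (top-p) sampling over a plain list of probabilities.

    Tokens are ranked by probability; a token is kept while the mass of the
    tokens ranked above it is <= p. One kept index is then drawn in
    proportion to its probability (random.choices renormalises the weights).
    """
    # Token indices sorted by probability, highest first
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass_before = [], 0.0
    for i in order:
        if mass_before > p:
            break
        kept.append(i)
        mass_before += probs[i]
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]
```

With probs = [0.5, 0.3, 0.15, 0.05] and p = 0.7, only indices 0 and 1 can be returned, because the first two tokens already carry 0.8 of the mass.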

I couldn't find much by Googling this, not helped by the fact that PyTorch itself seems to be part of Meta (Facebook) now.

There are these pages: https://pytorch.org/docs/stable/generated/torch.Tensor.is_meta.html and https://pytorch.org/torchdistx/latest/fake_tensor.html

Fake tensors, similar to meta tensors, carry no data; however, unlike meta tensors which report meta as their device, fake tensors act as if they were allocated on a real device.

I'm not sure why we end up on the meta device at this line of code.

It sounds like device_map=None may have left us without a device.

Adding an explicit device map in LLaMAInference like:

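        # Pick the best available device: CUDA, then MPS, then CPU.
        # (torch.has_cuda / torch.has_mps only report build support, not availability.)
        device = torch.device("cpu")
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")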

        if device_map is None:
            modules = (
                "transformer",
                "tok_embeddings",
                "layers",
                "norm",
                "output",
            )
            device_map = {module: device for module in modules}

...this gets further!
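Why the map needs an entry covering every top-level module becomes clearer from how a module-name device map is typically looked up: a parameter such as layers.0.attention.wq.weight takes the device of the nearest listed prefix of its dotted name. The hypothetical resolve_device helper below sketches that prefix lookup; it is an illustration, not Accelerate's actual code:

```python
def resolve_device(param_name, device_map):
    """Hypothetical helper: find the device for a parameter by walking up
    its dotted module path until a prefix appears in device_map.

    Returns None when nothing matches, i.e. the parameter has no device.
    """
    name = param_name
    while name:
        if name in device_map:
            return device_map[name]
        # Drop the last path component: "layers.0.attention" -> "layers.0"
        name = name.rpartition(".")[0]
    # In Accelerate, an empty-string key can place the whole model on one device
    return device_map.get("")


device_map = {"tok_embeddings": "mps", "layers": "mps", "norm": "mps", "output": "mps"}
```

With this map, resolve_device("layers.0.attention.wq.weight", device_map) resolves to "mps", while a name outside the listed modules resolves to None.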

Now I get this error:

vanilla-llama/llama/model.py", line 62, in apply_rotary_emb
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
NotImplementedError: The operator 'aten::view_as_complex' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
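The failing step treats each pair of features (a, b) as a complex number a + bi and multiplies it by cos θ + i sin θ. Since the product is (a cos θ - b sin θ) + i(a sin θ + b cos θ), the same rotation can be written with real arithmetic only, which would avoid complex dtypes on MPS entirely. A minimal pure-Python sketch (both function names hypothetical) comparing the two forms:

```python
import math


def rotate_pairs_complex(x, angles):
    """Reference: view (x[2k], x[2k+1]) as complex and multiply by e^(i*angle_k)."""
    out = []
    for k, theta in enumerate(angles):
        z = complex(x[2 * k], x[2 * k + 1]) * complex(math.cos(theta), math.sin(theta))
        out.extend([z.real, z.imag])
    return out


def rotate_pairs_real(x, angles):
    """Same rotation using only real arithmetic (no complex dtype needed)."""
    out = []
    for k, theta in enumerate(angles):
        a, b = x[2 * k], x[2 * k + 1]
        c, s = math.cos(theta), math.sin(theta)
        out.extend([a * c - b * s, a * s + b * c])
    return out
```

In PyTorch terms the real form would operate on xq[..., 0::2] and xq[..., 1::2] with precomputed cos/sin tensors, keeping everything in float dtypes; that translation is untested here.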

Finally, I tried setting the suggested PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable and got:

vanilla-llama/llama/model.py:63: UserWarning: The operator aten::view_as_complex appears to be a view operator, but it has no implementation for the backend "mps:0". View operators don't support falling back to run on the CPU, since the tensor's storage cannot be shared across devices. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/CPUFallback.cpp:181.)
  xk = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
libc++abi: terminating due to uncaught exception of type c10::TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.
Exception raised from getMPSScalarType at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/OperationUtils.mm:91 (most recent call first):
frame #0: at::native::mps::getMPSScalarType(c10::ScalarType) + 180 (0x134721278 in libtorch_cpu.dylib)

which makes running on the MPS device start to look like a dead end.

chris-hatton commented 1 year ago

Some great sleuthing there, @anentropic 👌 That last error is mentioned in this thread.

anentropic commented 1 year ago

It seems others have got this working:

https://github.com/jankais3r/LLaMA_MPS