microsoft / DirectML

DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers, including all DirectX 12-capable GPUs from vendors such as AMD, Intel, NVIDIA, and Qualcomm.
MIT License
2.22k stars, 296 forks

There is no way to load a tensor directly to the GPU #466

Open: ProgramCrafter opened this issue 1 year ago

ProgramCrafter commented 1 year ago

I'm attempting to load a large model and want it to load directly to the GPU, because I don't have enough RAM but do have enough VRAM.

I'm passing the DirectML device as the map_location argument to torch.load and get the following error:

```
Traceback (most recent call last):
  File "C:\Users\***\AppData\Roaming\Python\Python39\site-packages\transformers\modeling_utils.py", line 446, in load_state_dict
    return torch.load(checkpoint_file, map_location="cpu")
  File "E:\LLM-Dolly\.models\databricks\dolly-v2-12b\main.py", line 44, in <lambda>
    torch.load = lambda f, map_location, *a, **k: old_tl(f, self.device, *a, **k)
  File "C:\Users\***\AppData\Roaming\Python\Python39\site-packages\torch\serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\***\AppData\Roaming\Python\Python39\site-packages\torch\serialization.py", line 1172, in _load
    result = unpickler.load()
  File "C:\Users\***\AppData\Roaming\Python\Python39\site-packages\torch\serialization.py", line 1142, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "C:\Users\***\AppData\Roaming\Python\Python39\site-packages\torch\serialization.py", line 1116, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "C:\Users\***\AppData\Roaming\Python\Python39\site-packages\torch\serialization.py", line 1086, in restore_location
    return default_restore_location(storage, str(map_location))
  File "C:\Users\***\AppData\Roaming\Python\Python39\site-packages\torch\serialization.py", line 220, in default_restore_location
    raise RuntimeError("don't know how to restore data location of "
RuntimeError: don't know how to restore data location of torch.storage.UntypedStorage (tagged with privateuseone:0)
```
lshqqytiger commented 1 year ago

Unfortunately, because of how they are implemented, DirectML tensors do not have storage. Possibly related: #425

JonathanFly commented 1 year ago

Is there any way to move a model from the GPU back to system memory, in torch_directml specifically?

I'm trying to get the Bark audio model (https://github.com/suno-ai/bark) working with DirectML. Bark runs three models in sequence. In regular torch, a model that is no longer in use can be moved back to the CPU, but in DirectML, once a model has been moved with .to(dml), I can't move it back to the CPU. As a result, generating audio with Bark on DirectML requires 12GB of GPU memory instead of the 6GB it needs in regular torch.

I can delete the model completely after each step and then reload it from scratch, as if the program were starting fresh. This is obviously pretty slow. Is there a better way? I don't need to run inference on the CPU, so I don't actually need to move the model there. I just need a way to say, "Clear that model out of GPU memory for now..."
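For a Bark-style pipeline, one possible alternative to reloading each model from disk is to keep a CPU copy of each model's weights and build a device copy only for the stage that is running. This is a sketch under assumptions: `OffloadableModel`, `acquire`, `release` and `build_fn` (a function that constructs the model architecture) are hypothetical names, the state dict is assumed to be captured before the model is moved to DirectML, and the CPU copies of all stages stay in RAM. It relies on dropping every reference and calling `gc.collect()`; whether torch_directml hands that memory back to the GPU immediately is not confirmed here.

```python
import gc

import torch


class OffloadableModel:
    """Hypothetical helper: keeps weights on the CPU, builds a device copy on demand."""

    def __init__(self, build_fn, state_dict, device):
        self.build_fn = build_fn
        self.device = device
        # Detached CPU clones, so this copy never shares memory with a device model.
        self.cpu_state = {k: v.detach().to("cpu").clone() for k, v in state_dict.items()}
        self.model = None

    def acquire(self):
        """Return the device copy, building it from the CPU weights if needed."""
        if self.model is None:
            model = self.build_fn()
            model.load_state_dict(self.cpu_state)
            self.model = model.to(self.device)
        return self.model

    def release(self):
        """Drop this helper's reference to the device copy and run the garbage collector."""
        # Callers must also drop their own references (e.g. `del m`) for the
        # device tensors to become collectable.
        self.model = None
        gc.collect()
```

In a pipeline, each stage would call `m = stage.acquire()`, run inference, then `del m` and `stage.release()` before the next stage starts, so only one model is meant to occupy GPU memory at a time.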

My total experience with DirectML is about two hours of changing 'device' to 'dml' and googling for answers.

Adele101 commented 1 year ago

Hello, thank you for submitting this issue. While I can't provide a timeline for a resolution at the moment, please know that your feedback is valuable to us. We will follow up once we can review this issue.

VadimShabashov commented 3 months ago

Now there is no error or warning, but the tensor is still not loaded onto the device. In the following code, the tensor stays on the CPU:

```python
import torch
import torch_directml

device = torch_directml.device(torch_directml.default_device())
t = torch.load('tmp.pt', map_location=device)
print(t.device, device)  # cpu privateuseone:0
```

Moving the tensor to the device explicitly with .to(device) works as expected:

```python
import torch
import torch_directml

device = torch_directml.device(torch_directml.default_device())
t = torch.load('tmp.pt').to(device)
print(t.device, device)  # privateuseone:0 privateuseone:0
```

Versions:

torch==2.3.1
torch-directml==0.2.3dev240715
numpy==1.26.4