Unfortunately, by their implementation, DirectML tensors do not have backing storage. Possibly related: #425
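A quick way to see this locally (my own probe, not something from the torch_directml docs; I'm assuming the access fails with a RuntimeError rather than knowing the exact behavior):

import torch
import torch_directml

dml = torch_directml.device()
t = torch.ones(4).to(dml)
try:
    # untyped_storage() touches the tensor's backing storage; if the
    # DirectML backend really exposes none, this is where it should fail.
    print('storage size:', t.untyped_storage().size())
except RuntimeError as e:
    print('no accessible storage:', e)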
Is there any way to move a model from GPU to system memory? In torch_directml specifically.
I'm trying to get the Bark audio model https://github.com/suno-ai/bark working in DirectML. Bark uses 3 models in a sequence. But in DirectML, once a model has been moved with .to(dml), I can't move it again to put it on the CPU, even though it's no longer in use. This means generating audio with Bark in DirectML requires 12GB of GPU memory instead of the 6GB it needs in regular torch.
I can completely delete the model after each step and then reload it from scratch, as if it were a fresh program start. This is obviously pretty slow; is there a better way? I don't need to do inference on the model on the CPU, so I don't actually need to move it there. I just need a way to say, "Clear that model out of GPU memory for now..."
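For reference, the delete-and-reload workaround I mean looks roughly like this (a minimal sketch; the nn.Linear is a hypothetical stand-in for one of Bark's models, and whether gc.collect() actually returns the memory to the driver is exactly the part I'm unsure about):

import gc
import torch
import torch.nn as nn
import torch_directml

dml = torch_directml.device()

# Stand-in for one of Bark's three models (hypothetical placeholder).
model = nn.Linear(1024, 1024).to(dml)
x = torch.randn(1, 1024).to(dml)
y = model(x)  # run this stage's inference

# Drop every reference so the DirectML allocation can be freed, then
# rebuild/reload the model from disk when this stage is needed again.
del model, x, y
gc.collect()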
My total experience with DirectML is about 2 hours of changing 'device' to 'dml' and googling for answers.
Hello, thank you for submitting this issue. While I can't provide a timeline for resolution at the moment, please know that your feedback is valuable to us. We will follow up once we can review this issue.
Now there is no error or warning, but the tensor is not loaded properly. In the following code, the tensor is still on the CPU:
import torch
import torch_directml
device = torch_directml.device(torch_directml.default_device())
t = torch.load('tmp.pt', map_location=device)
print(t.device, device) # cpu privateuseone:0
Manually moving the tensor to the device works as expected:
import torch
import torch_directml
device = torch_directml.device(torch_directml.default_device())
t = torch.load('tmp.pt').to(device)
print(t.device, device) # privateuseone:0 privateuseone:0
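Since my underlying goal is to avoid materializing the whole checkpoint in system RAM (see below), a possible middle ground, which I haven't verified against DirectML, is torch.load's mmap=True option (available since torch 2.1 for checkpoints saved in the default zipfile format) followed by an explicit move:

import torch
import torch_directml

device = torch_directml.device(torch_directml.default_device())

# mmap=True keeps the checkpoint on disk and pages data in lazily, so the
# tensor is not fully materialized in RAM before being copied to the GPU.
t = torch.load('tmp.pt', mmap=True)
t = t.to(device)
print(t.device)  # expected: privateuseone:0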
Versions:
torch==2.3.1
torch-directml==0.2.3dev240715
numpy==1.26.4
I'm attempting to load a large model and want it to load directly onto the GPU, as I don't have enough system RAM but do have enough VRAM.
I'm feeding the DirectML device as the map_location argument to torch.load and receive the following error: