Hi, apologies if this is covered, but I haven't seen a solution for this in the code or docs yet.
I've got a model that I trained using MPS (Metal on an M1). I'd like to use that model elsewhere, but I'm having trouble loading the state dict. In Python, torch.load accepts a map_location kwarg that lets you load a checkpoint onto a specific device, but that argument doesn't exist in torch-rb.
Here's what I'm doing:
Torch::Backends::MPS.available? # => true
Torch.load('my-trained-model.pt') # => RuntimeError: supported devices include CPU, CUDA and HPU, however got MPS
Hi @bockets, I'm not sure how to reproduce. Are you training / saving the model in Ruby? Can you provide a minimal script for training / saving the model that causes the error?
I could also be doing this completely wrong.