ankane / torch.rb

Deep learning for Ruby, powered by LibTorch
Other
704 stars 30 forks

Cannot load trained model to MPS device #49

Closed: bockets closed this issue 7 months ago

bockets commented 8 months ago

Hi, apologies if this has already been covered, but I haven't found a solution for this in the code or docs yet.

I have a model that I trained on MPS (Metal on an M1). I'd like to use that model elsewhere, but I'm having trouble loading the state dict. In Python, torch.load has a map_location kwarg that lets you load a model onto a specific device, but torch-rb has no equivalent argument.

Here's what I'm doing:

Torch::Backends::MPS.available? # => true
Torch.load('my-trained-model.pt') # => RuntimeError: supported devices include CPU, CUDA and HPU, however got MPS

I could also be doing this completely wrong.

ankane commented 8 months ago

Hi @bockets, I'm not sure how to reproduce this. Are you training / saving the model in Ruby? Can you provide a minimal script for training / saving the model that causes the error?