princeton-vl / RAFT

BSD 3-Clause "New" or "Revised" License
3.23k stars, 629 forks

Fixing model loading on Apple Silicon (when DEVICE is set to "mps") #177

Open jonas-eschmann opened 10 months ago

jonas-eschmann commented 10 months ago

The checkpoint's tensors are tagged with a CUDA storage location, so that location has to be remapped when loading on non-CUDA devices. Before this fix, I got the following error (even after setting DEVICE='mps'):

jonas@Jonass-MacBook-Pro RAFT % python demo.py --model=models/raft-things.pth --path=demo-frames
Traceback (most recent call last):
  File "/Users/jonas/phd/courses/llvm/RAFT/demo.py", line 75, in <module>
    demo(args)
  File "/Users/jonas/phd/courses/llvm/RAFT/demo.py", line 44, in demo
    model.load_state_dict(torch.load(args.model))
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonas/venvs/sam/lib/python3.11/site-packages/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonas/venvs/sam/lib/python3.11/site-packages/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/jonas/venvs/sam/lib/python3.11/site-packages/torch/serialization.py", line 1392, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonas/venvs/sam/lib/python3.11/site-packages/torch/serialization.py", line 1366, in load_tensor
    wrap_storage=restore_location(storage, location),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonas/venvs/sam/lib/python3.11/site-packages/torch/serialization.py", line 381, in default_restore_location
    result = fn(storage, location)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonas/venvs/sam/lib/python3.11/site-packages/torch/serialization.py", line 274, in _cuda_deserialize
    device = validate_cuda_device(location)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonas/venvs/sam/lib/python3.11/site-packages/torch/serialization.py", line 258, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
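As the error message suggests, `torch.load` accepts a `map_location` argument that overrides the device recorded for each saved storage. The sketch below shows that pattern under a few assumptions: `pick_device` and `load_checkpoint` are hypothetical helpers (not part of RAFT), the MPS > CUDA > CPU preference order is an illustrative choice, and a tiny `nn.Linear` stands in for the RAFT network so the example runs on its own.

```python
import io

import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Hypothetical helper: prefer Apple's MPS backend, then CUDA, then CPU."""
    # torch.backends.mps only exists on PyTorch >= 1.12, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def load_checkpoint(model: nn.Module, source, device: torch.device) -> nn.Module:
    """Hypothetical helper: load a state dict from a path or file-like object.

    map_location remaps every stored tensor onto `device`, so a checkpoint that
    was saved on a CUDA machine no longer triggers validate_cuda_device when
    CUDA is unavailable.
    """
    state_dict = torch.load(source, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


if __name__ == "__main__":
    device = pick_device()

    # Stand-in for the RAFT model: save its weights to an in-memory buffer.
    net = nn.Linear(4, 2)
    buffer = io.BytesIO()
    torch.save(net.state_dict(), buffer)
    buffer.seek(0)

    restored = load_checkpoint(nn.Linear(4, 2), buffer, device)
    print("weights loaded on", next(restored.parameters()).device)
```

In `demo.py` itself, the equivalent change would presumably be passing `map_location=DEVICE` to the `torch.load(args.model)` call shown in the traceback, so the `raft-things.pth` weights land directly on `mps` (or on the CPU) instead of being validated against CUDA.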