researchmm / TTSR

[CVPR'20] TTSR: Learning Texture Transformer Network for Image Super-Resolution
MIT License

running on CPU only fails #23

Closed. mtedaldi closed this issue 3 years ago.

mtedaldi commented 3 years ago

I am using the switch --cpu TRUE to run on a machine without an NVIDIA GPU.

After the download, the software now errors out with:

marcot@AEQ128:~/build/TTSR$ sh test.sh
[2020-10-22 14:27:59,255] - [trainer.py file line:48] - INFO: load_model_path: /home/marcot/model/TTSR.pt
Traceback (most recent call last):
  File "main.py", line 35, in <module>
    t.load(model_path=args.model_path)
  File "/home/marcot/build/TTSR/trainer.py", line 50, in load
    model_state_dict_save = {k:v for k,v in torch.load(model_path).items()}
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 585, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 765, in _legacy_load
    result = unpickler.load()
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 721, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 174, in default_restore_location
    result = fn(storage, location)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 150, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 134, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

So it seems that I should be able to fix this myself. I'll submit a fix as soon as I have a solution :)

Python version:

marcot@AEQ128:~/build/TTSR$ python --version
Python 3.7.3
mylyu commented 3 years ago

Just use torch.load with map_location=torch.device('cpu'), i.e. torch.load(model_path, map_location=torch.device('cpu')).
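For reference, the suggested fix can be folded into the load step from the traceback (trainer.py, line 50). The sketch below is a hypothetical helper, not the repository's actual patch: the name load_state_dict_portable and its use_cpu argument (standing in for the --cpu switch) are assumptions. It selects the CPU when requested or when CUDA is unavailable, and passes that device to torch.load as map_location.

```python
import torch


def load_state_dict_portable(model_path, use_cpu=False):
    """Hypothetical helper: load a checkpoint that may have been saved on a GPU.

    `use_cpu` is an assumed stand-in for the repo's --cpu switch.
    """
    # Fall back to the CPU when asked to, or when no CUDA device is present.
    if use_cpu or not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda')
    # map_location remaps every saved CUDA storage onto `device`, which avoids
    # "Attempting to deserialize object on a CUDA device ...".
    state = torch.load(model_path, map_location=device)
    # Same dict copy as the original line 50 in trainer.py.
    return {k: v for k, v in state.items()}
```

With this helper, line 50 would become something like model_state_dict_save = load_state_dict_portable(model_path, use_cpu=True); how the --cpu value actually reaches that call in TTSR is likewise an assumption.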

FuzhiYang commented 3 years ago

@mtedaldi We have fixed the bug that occurred when running on CPU.