Nice catch. In other code we use this approach.
Have you already seen it?
Would this or that approach be more ideal?
Hi Andrew! Yeah, I came across this solution a few hours after posting (D'oh!). So, that's a really good question, and one that appears much simpler than it is. According to the documentation (https://pytorch.org/docs/stable/generated/torch.load.html):
torch.load() uses Python’s unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn’t have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.
According to that, the two solutions should behave identically, first mapping to CPU and then to GPU (or staying on CPU, if the user requested it). But, to check whether any hidden overheads existed, I ran a little token QA-model loading task (below), refreshing the kernel between Jupyter cells to eliminate the effects of any JIT compilation etc.:
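The original notebook cells aren't reproduced here, but the comparison was along these lines: a minimal sketch, with a placeholder checkpoint path and device, and assuming the checkpoint holds a plain state_dict of tensors.

```python
import time
import torch

# Placeholder path and device; stand-ins for the QA model checkpoint used in the actual test.
checkpoint_path = "qa_model.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Cell 1 - Andrew's suggestion: map storages straight onto the target device at load time.
start = time.perf_counter()
state = torch.load(checkpoint_path, map_location=device)
print(f"map_location load: {time.perf_counter() - start:.3f} s")

# Cell 2 - original proposal: load onto CPU first, then move tensors to the target device.
start = time.perf_counter()
state = torch.load(checkpoint_path, map_location="cpu")
state = {k: v.to(device) for k, v in state.items()}  # assumes a plain state_dict of tensors
print(f"CPU load then .to(device): {time.perf_counter() - start:.3f} s")
```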
Surprisingly, the solution you proposed (top cell) was ~7 times faster! I'll modify my proposed pull request accordingly; thank you for the tip! Best wishes, V
Wow, 7 times! Thanks for running the quick experiment.
A quick tip: you can use the magic commands in Jupyter to do the timings for you in a less code-intensive way : )
e.g., %%time, %%timeit
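For instance, with a placeholder checkpoint path, the whole timing could boil down to a short cell like this sketch:

```python
%%timeit
# %%timeit reruns the cell body several times and reports the mean and std dev automatically.
import torch
state = torch.load("qa_model.pt", map_location="cuda:0")
```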
Neat - ta muchly!
I use my QA projects across devices, and a problem I encountered was deserialization failing when the GPU/device used to create the baseline model was different from the device used for further training from that baseline. Specifically, I had trained the baseline on a machine with multiple GPUs and was retraining on a laptop with a single GPU. The error message was identical to the one outlined here: https://stackoverflow.com/questions/56369030/runtimeerror-attempting-to-deserialize-object-on-a-cuda-device, where PyTorch users attempted to retrain models using CPU instead of GPU. In my case, I had used GPU ID 1 on a multi-GPU system for baseline training and was trying to use GPU 0 for retraining on a single-GPU system. All other software (conda env, Ubuntu 20.04, NVIDIA drivers, etc.) was identical between systems, and changing the GPU ID in the config.ini file was not sufficient to resolve the error.

The PyTorch deserialization error was traced via stderr to line 169 of this file (train_model.py), corresponding to the torch.load() line. As outlined in the Stack Overflow post above, the solution is to assign the device to load the model to unambiguously, using the map_location argument of PyTorch's load() function. By retrieving the device variable one line earlier and mapping unambiguously to the desired device via map_location in line 169, the error was resolved with no other changes in behaviour. A sketch of the change is below.
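In sketch form (variable names and the checkpoint path are illustrative, not the actual identifiers in train_model.py), the fix amounts to something like:

```python
import torch

# Hypothetical stand-ins for values the real code reads from config.ini / the run directory.
gpu_id = 0
checkpoint_path = "baseline_qa_model.pt"

# Retrieve the device one line earlier...
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

# ...then deserialize unambiguously onto that device. Without map_location, torch.load()
# tries to restore storages to the device they were saved from (e.g. cuda:1), which raises
# the RuntimeError above on a machine that only exposes cuda:0.
model_state = torch.load(checkpoint_path, map_location=device)
```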