Closed dearleiii closed 6 years ago
Recommended approach for saving a model
There are two main approaches for serializing and restoring a model.
The first (recommended) saves and loads only the model parameters:
```python
torch.save(the_model.state_dict(), PATH)
```

Then later:

```python
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
```

The second saves and loads the entire model:

```python
torch.save(the_model, PATH)
```

Then later:

```python
the_model = torch.load(PATH)
```

However, in this case the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after serious refactors.
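A minimal end-to-end sketch of the recommended (state_dict) approach. The model class, layer sizes, and file path here are hypothetical stand-ins, not from the original post:

```python
import torch
import torch.nn as nn

# Hypothetical model class standing in for TheModelClass
class TheModelClass(nn.Module):
    def __init__(self, in_features=4, out_features=2):
        super(TheModelClass, self).__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)

PATH = "model_params.pth"

# Save only the parameters (recommended)
the_model = TheModelClass()
torch.save(the_model.state_dict(), PATH)

# Later: rebuild the architecture first, then load the parameters into it
restored = TheModelClass()
restored.load_state_dict(torch.load(PATH))
restored.eval()  # switch to inference mode if training is finished
```

Because only tensors are serialized, the checkpoint stays loadable even if the surrounding project layout changes, as long as the class definition is importable.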
```
inputs: torch.Size([300, 3, 1020, 2040])
scores: torch.Size([300, 1])
Traceback (most recent call last):
  File "load_model_test.py", line 65, in <module>
    outputs = model1(inputs)
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 110, in forward
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 121, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 36, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 29, in scatter
    return scatter_map(inputs)
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 16, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 14, in scatter_map
    return Scatter.apply(target_gpus, None, dim, obj)
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/_functions.py", line 73, in forward
    streams = [_get_stream(device) for device in ctx.target_gpus]
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/_functions.py", line 73, in <listcomp>
    streams = [_get_stream(device) for device in ctx.target_gpus]
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/parallel/_functions.py", line 100, in _get_stream
    if _streams[device] is None:
IndexError: list index out of range
```

```
leichen@gpu-compute3$ python3 load_model_test.py
```
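This traceback looks like the whole-model failure mode described above: `torch.load` restores the `nn.DataParallel` wrapper together with the `device_ids` it was saved with, and `_get_stream` then indexes a per-GPU stream list sized for the *current* machine, so GPU ids from the original machine fall out of range. One possible workaround (a sketch, assuming the checkpoint was saved from a `DataParallel`-wrapped model; the file name, model class, and args are hypothetical) is to load only the `state_dict` mapped to CPU, strip the `module.` prefix that `DataParallel` adds to parameter keys, and re-wrap for whatever GPUs are actually available:

```python
from collections import OrderedDict

def strip_data_parallel_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel prepends to keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len("module."):] if key.startswith("module.") else key
        cleaned[new_key] = value
    return cleaned

# Hypothetical usage, assuming the checkpoint holds a state_dict:
#
# import torch
# import torch.nn as nn
# state = torch.load("checkpoint.pth", map_location="cpu")
# model = TheModelClass(*args, **kwargs)
# model.load_state_dict(strip_data_parallel_prefix(state))
# if torch.cuda.is_available():
#     # device_ids defaults to all GPUs visible on *this* machine
#     model = nn.DataParallel(model.cuda())
```

Re-wrapping on the target machine means the `device_ids` always match the local GPU count instead of being frozen into the serialized object.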