AlanConstantine opened 1 year ago
I ran into an issue when trying to load a pre-trained TS2Vec model.
Here is the code:
```python
# train.py
model = TS2Vec(
    input_dims=input_dims,
    device=device,
    batch_size=batch_size,
    output_dims=output_dims
)
loss_log = model.fit(X, verbose=verbose, n_epochs=n_epochs)
model.save(r'./models/mv_10_1_5_model.pkl')
```

```python
# predict.py
model = TS2Vec(
    input_dims=input_dims,
    device=device,
    batch_size=batch_size,
    output_dims=output_dims
)
model.load(r'./models/mv_10_1_5_model.pkl')
```
And here is the error:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_5364\1224328800.py in <module>
     10     output_dims=output_dims
     11 )
---> 12 model.load('./models/mv_10_1_5_model.pkl')

E:\Workplace\ts2vec\ts2vec.py in load(self, fn)
    315             fn (str): filename.
    316         '''
--> 317         state_dict = torch.load(fn, map_location=self.device)
    318         self.net.load_state_dict(state_dict)
    319

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    710                     opened_file.seek(orig_position)
    711                     return torch.jit.load(opened_file)
--> 712             return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    713     return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    714

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1044     unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1045     unpickler.persistent_load = persistent_load
-> 1046     result = unpickler.load()
   1047
   1048     torch._utils._validate_loaded_sparse_tensors()

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in persistent_load(saved_id)
   1014         if key not in loaded_storages:
   1015             nbytes = numel * torch._utils._element_size(dtype)
-> 1016             load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
   1017
   1018         return loaded_storages[key]

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in load_tensor(dtype, numel, key, location)
    999         # stop wrapping with _TypedStorage
   1000         loaded_storages[key] = torch.storage._TypedStorage(
-> 1001             wrap_storage=restore_location(storage, location),
   1002             dtype=dtype)
   1003

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in restore_location(storage, location)
    974     else:
    975         def restore_location(storage, location):
--> 976             result = map_location(storage, location)
    977             if result is None:
    978                 result = default_restore_location(storage, location)

TypeError: 'int' object is not callable
```
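The last frame shows `torch.load` calling `map_location` as if it were a function. `torch.load` documents `map_location` as a function, a `torch.device`, a string, or a dict. Anything else, including a bare GPU index such as `0`, falls through to the "callable" branch and raises this `TypeError`. A likely cause is that `device` was passed to `TS2Vec` as an integer. Training presumably still works because `.to(0)` accepts an index. A minimal sketch, assuming that is the case, normalizes the device before it reaches `torch.load`. `to_map_location` is a hypothetical helper, not part of TS2Vec:

```python
def to_map_location(device: object) -> str:
    """Hypothetical helper: turn a TS2Vec-style device argument into a
    string that torch.load accepts as map_location."""
    if device is None:
        # No device given: load onto the CPU.
        return "cpu"
    if isinstance(device, bool):
        # bool is a subclass of int; reject it so we never build "cuda:True".
        raise TypeError("device must be an int, a str, a torch.device, or None")
    if isinstance(device, int):
        # A bare GPU index such as 0 becomes "cuda:0".
        return f"cuda:{device}"
    if isinstance(device, str):
        # Strings such as "cpu" or "cuda:1" are already valid.
        return device
    # torch.device objects stringify to e.g. "cuda:0" or "cpu";
    # using str() keeps this helper free of a torch import.
    return str(device)


print(to_map_location(0))      # cuda:0
print(to_map_location("cpu"))  # cpu
print(to_map_location(None))   # cpu
```

Inside `load`, the call would then become `torch.load(fn, map_location=to_map_location(self.device))`. Alternatively, construct the model with `device='cuda:0'` or `torch.device('cuda', 0)` in the first place.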
Is there an example of saving and loading a pre-trained model? I would really appreciate any help with this issue!
I was able to load the model successfully after removing `map_location=self.device`. Here is the solution:
```python
state_dict = torch.load(path)
model.net.load_state_dict(state_dict)
```
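Dropping `map_location` works as long as the checkpoint is loaded where the devices it was saved from exist, because `torch.load` then restores each tensor to its original device. A GPU-saved checkpoint opened on a CPU-only machine would fail. The sketch below shows a more portable round trip that passes `map_location` as a string. It assumes only standard PyTorch: the `nn.Linear` is a stand-in for `model.net`, and an in-memory buffer stands in for the `.pkl` file. Saving only the `state_dict` matches what the workaround above loads.

```python
import io

import torch
import torch.nn as nn

# Stand-in for TS2Vec's model.net; any nn.Module saves and loads the same way.
net = nn.Linear(4, 2)

# Save only the weights (the state_dict), not the whole module.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

# Remap every tensor onto the CPU. To target a GPU, pass "cuda:0" or
# torch.device("cuda", 0): a string or device object, never a bare int.
state_dict = torch.load(buffer, map_location="cpu")

# Rebuild a module with the same architecture and copy the weights in.
restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)

same = all(
    torch.equal(a, b)
    for a, b in zip(net.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```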