zhihanyue / ts2vec

A universal time series representation learning framework
MIT License

Pretrained model loading problem #34

Open AlanConstantine opened 1 year ago

AlanConstantine commented 1 year ago

I ran into an issue when trying to load a pre-trained TS2Vec model.

Here is the code:

# train.py
from ts2vec import TS2Vec

# input_dims, device, batch_size, output_dims, X, verbose and n_epochs are defined elsewhere (not shown)
model = TS2Vec(
    input_dims=input_dims,
    device=device,
    batch_size=batch_size,
    output_dims=output_dims
)
loss_log = model.fit(X, verbose=verbose, n_epochs=n_epochs)
model.save(r'./models/mv_10_1_5_model.pkl')

# predict.py
from ts2vec import TS2Vec

model = TS2Vec(
    input_dims=input_dims,
    device=device,
    batch_size=batch_size,
    output_dims=output_dims
)
model.load(r'./models/mv_10_1_5_model.pkl')

And here is the error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_5364\1224328800.py in <module>
     10     output_dims=output_dims
     11 )
---> 12 model.load('./models/mv_10_1_5_model.pkl')

E:\Workplace\ts2vec\ts2vec.py in load(self, fn)
    315             fn (str): filename.
    316         '''
--> 317         state_dict = torch.load(fn, map_location=self.device)
    318         self.net.load_state_dict(state_dict)
    319 

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    710                     opened_file.seek(orig_position)
    711                     return torch.jit.load(opened_file)
--> 712                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    713         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    714 

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1044     unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1045     unpickler.persistent_load = persistent_load
-> 1046     result = unpickler.load()
   1047 
   1048     torch._utils._validate_loaded_sparse_tensors()

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in persistent_load(saved_id)
   1014         if key not in loaded_storages:
   1015             nbytes = numel * torch._utils._element_size(dtype)
-> 1016             load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
   1017 
   1018         return loaded_storages[key]

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in load_tensor(dtype, numel, key, location)
    999         # stop wrapping with _TypedStorage
   1000         loaded_storages[key] = torch.storage._TypedStorage(
-> 1001             wrap_storage=restore_location(storage, location),
   1002             dtype=dtype)
   1003 

e:\Software\anaconda3\lib\site-packages\torch\serialization.py in restore_location(storage, location)
    974     else:
    975         def restore_location(storage, location):
--> 976             result = map_location(storage, location)
    977             if result is None:
    978                 result = default_restore_location(storage, location)

TypeError: 'int' object is not callable

Do you have an example of saving and loading a pre-trained model? I would really appreciate it if you could help address this issue!

AlanConstantine commented 1 year ago

I was able to load the model successfully after removing map_location=self.device. Here is the solution:

import torch
# path is the file written by model.save(), e.g. './models/mv_10_1_5_model.pkl'
state_dict = torch.load(path)
model.net.load_state_dict(state_dict)
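Judging from the traceback, torch ended up in the branch of restore_location that calls map_location as a function, which is what happens when map_location is not None, a dict, a string, or a torch.device. That suggests device was passed to TS2Vec as a bare integer GPU index such as device=0 (an assumption, since its value is not shown above). Dropping map_location works, but it loads the tensors back onto whatever device they were saved from. A minimal sketch of an alternative that keeps map_location by normalizing the device first; as_map_location is a hypothetical helper, not part of TS2Vec or torch:

```python
def as_map_location(device):
    """Convert a TS2Vec-style device spec into a value torch.load accepts.

    Hypothetical helper: assumes a bare int means a CUDA device index.
    """
    # bool is a subclass of int; reject it rather than mapping True to 'cuda:1'
    if isinstance(device, bool):
        raise TypeError(f"invalid device: {device!r}")
    if isinstance(device, int):
        # map_location accepts a device string such as 'cuda:0'
        return f"cuda:{device}"
    # strings, torch.device objects, dicts and callables pass through unchanged
    return device


if __name__ == "__main__":
    print(as_map_location(0))
    print(as_map_location("cpu"))
```

With it, the call inside load could stay as torch.load(fn, map_location=as_map_location(self.device)). Passing map_location='cpu' is another option, since load_state_dict copies the loaded tensors into the network's existing parameters on whatever device they already live on.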