QData / spacetimeformer

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."
https://arxiv.org/abs/2109.12218
MIT License

load_from_checkpoint error #13

Closed: SingleXuu closed this issue 2 years ago.

SingleXuu commented 3 years ago

```python
model = spacetimeformer_model.Spacetimeformer_Forecaster(d_x=6, d_y=6)
model.load_from_checkpoint(check_point)
data_module, inv_scaler, null_val = create_dset()
trainer = pl.Trainer()

trainer.test(model=model, datamodule=data_module)
```

This raises an error: `RuntimeError: Error(s) in loading state_dict for Spacetimeformer_Forecaster: Unexpected key(s) in state_dict:`

Could you tell me how to solve it? Thank you very much!

jakegrigsby commented 2 years ago

I am not well equipped to debug this, because the results in the paper come from live logging statistics; there were very few (if any) times I actually needed to save a model to disk and load it back. I added that feature very early in development, mostly because PyTorch Lightning does much of the work for us. My best guess is that you are missing kwargs in the model construction (d_model, d_ff, normalization info). The model needs those to build the correct architecture with random weights before the weights are replaced with the ones from disk.
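The guess above can be made concrete: a model's parameter shapes are fixed by its constructor arguments, so weights saved from one configuration only load into a model built with the same arguments. The sketch below is a pure-Python stand-in with no PyTorch dependency. `ToyForecaster` and `strict_load` are hypothetical names, the defaults (`d_model=512`, `d_ff=2048`, `enc_layers=2`) are assumptions chosen to echo the 512-wide shapes in the traceback further down, and `strict_load` only imitates the checks behind the `Unexpected key(s)` and `size mismatch` messages of `load_state_dict`.

```python
from dataclasses import dataclass


@dataclass
class ToyForecaster:
    """Hypothetical stand-in whose parameter shapes follow its constructor kwargs."""

    d_model: int = 512  # assumed default, echoing the 512-wide shapes in the traceback
    d_ff: int = 2048  # assumed default
    enc_layers: int = 2  # assumed default

    def state_shapes(self):
        """Return {parameter name: shape}, like a state_dict reduced to shapes."""
        shapes = {}
        for i in range(self.enc_layers):
            prefix = f"encoder.attn_layers.{i}"
            shapes[f"{prefix}.query_projection.weight"] = (self.d_model, self.d_model)
            shapes[f"{prefix}.ff.weight"] = (self.d_ff, self.d_model)
        return shapes


def strict_load(model, checkpoint_shapes):
    """Imitate the key/shape checks of a strict load; raise on any difference."""
    current = model.state_shapes()
    errors = []
    unexpected = sorted(set(checkpoint_shapes) - set(current))
    missing = sorted(set(current) - set(checkpoint_shapes))
    if unexpected:
        errors.append(f"Unexpected key(s): {unexpected}")
    if missing:
        errors.append(f"Missing key(s): {missing}")
    for key in sorted(set(current) & set(checkpoint_shapes)):
        if current[key] != checkpoint_shapes[key]:
            errors.append(
                f"size mismatch for {key}: checkpoint {checkpoint_shapes[key]}, "
                f"current model {current[key]}"
            )
    if errors:
        raise RuntimeError("\n".join(errors))
    return True


# "Checkpoint" from a model trained with the flags used later in this thread.
saved = ToyForecaster(d_model=100, d_ff=400, enc_layers=4).state_shapes()

# Same kwargs as at training time: loads cleanly.
strict_load(ToyForecaster(d_model=100, d_ff=400, enc_layers=4), saved)

# Defaults only: extra layers are "unexpected" and shared layers mismatch in size.
try:
    strict_load(ToyForecaster(), saved)
except RuntimeError as err:
    print(err)
```

Rebuilding with the training-time kwargs loads cleanly, while the all-defaults model reports both unexpected keys (the extra encoder layers) and size mismatches, the same two symptoms reported in this thread.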

vinh-cao commented 2 years ago

Using this code I get the same error, with many shape mismatches.

```python
import train

CHECKPOINT = 'spacetimeformer/data/stf_model_checkpoints/spatiotemporal_toy2_613732368/spatiotemporal_toy2epoch=45-val/loss=0.05.ckpt'
config = ['spacetimeformer', 'toy2', '--run_name', 'spatiotemporal_toy2',
          '--d_model', 100, '--d_ff', 400, '--enc_layers', 4, '--dec_layers', 4,
          '--gpus','2','3', '--batch_size', 32, '--start_token_len', 4,
          '--n_heads', 4,'--grad_clip_norm', 1, '--early_stopping', '--trials', 1]

parser = train.create_parser('spacetimeformer', 'toy2')  # I modified the function to enable passing model and dset.
config = parser.parse_args([str(i) if type(i) == int else i for i in config])
forecaster = train.create_model(config)
forecaster.load_from_checkpoint(CHECKPOINT)
```
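Separately from the missing kwargs, note how the loader is called here and in the first post. In PyTorch Lightning, `load_from_checkpoint` is a classmethod: it constructs a new model (from hyperparameters stored in the checkpoint, if the module saved them, merged with any extra `**kwargs` passed to the call), loads the weights into it, and returns it. Calling it on an existing instance without assigning the result leaves that instance unchanged, and the new model ignores the instance's own configuration. That would be consistent with the traceback below showing a 512-wide "current model" even though `--d_model 100` was passed. The sketch below reproduces the pattern with a hypothetical `Loadable` class; it is not Lightning's code.

```python
class Loadable:
    """Hypothetical model whose loader is a classmethod, like Lightning's."""

    def __init__(self, d_model=512):  # assumed default width
        self.d_model = d_model
        self.weights = None

    @classmethod
    def load_from_checkpoint(cls, checkpoint, **kwargs):
        # Build a *new* instance from saved hyperparameters plus caller kwargs,
        # then fill in its weights; the caller's instance is never touched.
        # The "hyper_parameters" key name is an assumption modeled on Lightning.
        hparams = {**checkpoint.get("hyper_parameters", {}), **kwargs}
        model = cls(**hparams)
        model.weights = checkpoint["state_dict"]
        return model


checkpoint = {"state_dict": {"w": [1.0, 2.0]}}  # no saved hyperparameters

# Pitfall: the return value is discarded, so `forecaster` keeps weights=None.
forecaster = Loadable(d_model=100)
forecaster.load_from_checkpoint(checkpoint)

# Fix: call on the class, assign the result, and pass the architecture kwargs.
restored = Loadable.load_from_checkpoint(checkpoint, d_model=100)
```

With the real library, the corresponding call would presumably be `forecaster = Spacetimeformer_Forecaster.load_from_checkpoint(CHECKPOINT, d_model=100, d_ff=400, ...)`; the keyword names here are assumed from the CLI flags in this thread and should be checked against the class's constructor signature.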

Errors:

```
Traceback (most recent call last):
  File "/home/caq/.pycharm_helpers/pydev/pydevd.py", line 1448, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/caq/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/spacetimeformer/spacetimeformer/demo.py", line 37, in <module>
    forecaster.load_from_checkpoint(CHECKPOINT)
  File "/home/caq/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 157, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/caq/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 205, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "/home/caq/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Spacetimeformer_Forecaster:
    Unexpected key(s) in state_dict: "spacetimeformer.encoder.attn_layers.2.local_attention.inner_attention.projection_matrix", "spacetimeformer.encoder.attn_layers.2.local_attention.inner_attention.calls_since_last_redraw", "spacetimeformer.encoder.attn_layers.2.local_attention.query_projection.weight", "spacetimeformer.encoder.attn_layers.2.local_attention.query_projection.bias", "spacetimeformer.encoder.attn_layers.2.local_attention.key_projection.weight", "spacetimeformer.encoder.attn_layers.2.local_attention.key_projection.bias",  ..., "spacetimeformer.decoder.norm.norm.num_batches_tracked".
    size mismatch for spacetimeformer.embedding.x_emb.embed_weight: copying a param with shape torch.Size([7, 12]) from checkpoint, the shape in current model is torch.Size([5, 6]).
    size mismatch for spacetimeformer.embedding.x_emb.embed_bias: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([6]).
    size mismatch for spacetimeformer.embedding.y_emb.weight: copying a param with shape torch.Size([100, 85]) from checkpoint, the shape in current model is torch.Size([512, 31]).
    size mismatch for spacetimeformer.embedding.y_emb.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for spacetimeformer.embedding.var_emb.weight: copying a param with shape torch.Size([20, 100]) from checkpoint, the shape in current model is torch.Size([1, 512]).
    size mismatch for spacetimeformer.embedding.given_emb.weight: copying a param with shape torch.Size([2, 100]) from checkpoint, the shape in current model is torch.Size([2, 512]).
    size mismatch for spacetimeformer.encoder.attn_layers.0.global_attention.inner_attention.projection_matrix: copying a param with shape torch.Size([80, 25]) from checkpoint, the shape in current model is torch.Size([266, 64]).
    size mismatch for spacetimeformer.encoder.attn_layers.0.global_attention.query_projection.weight: copying a param with shape torch.Size([100, 100]) from checkpoint, the shape in current model is torch.Size([512, 512]).
```
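In the traceback, the 100-versus-512 pattern in `query_projection.weight`, `var_emb.weight`, and `given_emb.weight` matches the `--d_model 100` flag against a 512-wide freshly built model, and the unexpected `attn_layers.2` keys suggest the rebuilt encoder has fewer layers than the checkpoint. A small helper that diffs two `{name: shape}` maps makes such mismatches easier to scan. The sketch below uses a hypothetical `diff_shapes` on a few shapes copied from the traceback; with real files, the maps could presumably be built as `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu")["state_dict"].items()}` and likewise from `model.state_dict()`.

```python
def diff_shapes(checkpoint_shapes, model_shapes):
    """Compare {param_name: shape} maps the way a strict load_state_dict would."""
    unexpected = sorted(set(checkpoint_shapes) - set(model_shapes))
    missing = sorted(set(model_shapes) - set(checkpoint_shapes))
    mismatched = {
        name: (checkpoint_shapes[name], model_shapes[name])
        for name in sorted(set(checkpoint_shapes) & set(model_shapes))
        if checkpoint_shapes[name] != model_shapes[name]
    }
    return unexpected, missing, mismatched


# A few shapes copied from the traceback (the "spacetimeformer." prefix is dropped).
checkpoint_shapes = {
    "embedding.var_emb.weight": (20, 100),
    "embedding.given_emb.weight": (2, 100),
    "encoder.attn_layers.0.global_attention.query_projection.weight": (100, 100),
    # Listed only as an unexpected key; its shape is not shown in the traceback.
    "encoder.attn_layers.2.local_attention.query_projection.weight": None,
}
model_shapes = {
    "embedding.var_emb.weight": (1, 512),
    "embedding.given_emb.weight": (2, 512),
    "encoder.attn_layers.0.global_attention.query_projection.weight": (512, 512),
}

unexpected, missing, mismatched = diff_shapes(checkpoint_shapes, model_shapes)
for name, (saved, current) in mismatched.items():
    print(f"{name}: checkpoint {saved} vs model {current}")
```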
SingleXuu commented 2 years ago, reposting the code and traceback above verbatim.