Closed SingleXuu closed 2 years ago
I am not well equipped to debug this because the results used in the paper are based on live logging statistics, and there were very few (if any) times I actually needed to save and load a model from disk. I included that feature very early in the development process, mostly because pytorch lightning does a lot of the work for us. If I had to guess, I would say that you are probably missing kwargs in the model construction (d_model, d_ff, normalization info) that allow the model to build the correct architecture with random weights and then replace them from disk.
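For reference, a minimal sketch of that approach (PyTorch Lightning forwards extra keyword arguments from `load_from_checkpoint` to the model's constructor, so they can override missing or stale hyperparameters; the kwarg names below mirror the CLI flags and may not exactly match the `Spacetimeformer_Forecaster` signature):

```python
from spacetimeformer import spacetimeformer_model  # import path may vary

# load_from_checkpoint is a classmethod: it builds a fresh model from the
# kwargs below (random weights), then replaces the weights from disk, so
# every architecture kwarg must match the training run exactly.
model = spacetimeformer_model.Spacetimeformer_Forecaster.load_from_checkpoint(
    "path/to/checkpoint.ckpt",  # hypothetical path
    d_x=6, d_y=6,               # dataset dimensions used at training time
    d_model=100, d_ff=400,      # architecture values from the training config
    n_heads=4,
)
```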
Using this code I get the same error; there are lots of shape mismatches.
```python
import train

CHECKPOINT = 'spacetimeformer/data/stf_model_checkpoints/spatiotemporal_toy2_613732368/spatiotemporal_toy2epoch=45-val/loss=0.05.ckpt'

config = ['spacetimeformer', 'toy2', '--run_name', 'spatiotemporal_toy2',
          '--d_model', 100, '--d_ff', 400, '--enc_layers', 4, '--dec_layers', 4,
          '--gpus', '2', '3', '--batch_size', 32, '--start_token_len', 4,
          '--n_heads', 4, '--grad_clip_norm', 1, '--early_stopping', '--trials', 1]

parser = train.create_parser('spacetimeformer', 'toy2')  # I modified the function to enable passing model and dset.
config = parser.parse_args([str(i) if type(i) == int else i for i in config])
forecaster = train.create_model(config)
forecaster.load_from_checkpoint(CHECKPOINT)
```
Errors:
```
Traceback (most recent call last):
  File "/home/caq/.pycharm_helpers/pydev/pydevd.py", line 1448, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/caq/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/spacetimeformer/spacetimeformer/demo.py", line 37, in <module>
    forecaster.load_from_checkpoint(CHECKPOINT)
  File "/home/caq/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 157, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/caq/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 205, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "/home/caq/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Spacetimeformer_Forecaster:
    Unexpected key(s) in state_dict: "spacetimeformer.encoder.attn_layers.2.local_attention.inner_attention.projection_matrix", "spacetimeformer.encoder.attn_layers.2.local_attention.inner_attention.calls_since_last_redraw", "spacetimeformer.encoder.attn_layers.2.local_attention.query_projection.weight", "spacetimeformer.encoder.attn_layers.2.local_attention.query_projection.bias", "spacetimeformer.encoder.attn_layers.2.local_attention.key_projection.weight", "spacetimeformer.encoder.attn_layers.2.local_attention.key_projection.bias", ..., "spacetimeformer.decoder.norm.norm.num_batches_tracked".
    size mismatch for spacetimeformer.embedding.x_emb.embed_weight: copying a param with shape torch.Size([7, 12]) from checkpoint, the shape in current model is torch.Size([5, 6]).
    size mismatch for spacetimeformer.embedding.x_emb.embed_bias: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([6]).
    size mismatch for spacetimeformer.embedding.y_emb.weight: copying a param with shape torch.Size([100, 85]) from checkpoint, the shape in current model is torch.Size([512, 31]).
    size mismatch for spacetimeformer.embedding.y_emb.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for spacetimeformer.embedding.var_emb.weight: copying a param with shape torch.Size([20, 100]) from checkpoint, the shape in current model is torch.Size([1, 512]).
    size mismatch for spacetimeformer.embedding.given_emb.weight: copying a param with shape torch.Size([2, 100]) from checkpoint, the shape in current model is torch.Size([2, 512]).
    size mismatch for spacetimeformer.encoder.attn_layers.0.global_attention.inner_attention.projection_matrix: copying a param with shape torch.Size([80, 25]) from checkpoint, the shape in current model is torch.Size([266, 64]).
    size mismatch for spacetimeformer.encoder.attn_layers.0.global_attention.query_projection.weight: copying a param with shape torch.Size([100, 100]) from checkpoint, the shape in current model is torch.Size([512, 512]).
```
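Note: `load_from_checkpoint` is a classmethod, so `forecaster.load_from_checkpoint(CHECKPOINT)` ignores the `forecaster` built from `config`, constructs a brand-new model from defaults (which would explain the 512-vs-100 size mismatches above), and throws away the return value. A hedged sketch of the fix, keeping the returned model and re-supplying the training-time values (kwarg names assumed to match the constructor):

```python
# Keep the returned model; the extra kwargs override any saved or default
# hyperparameters so the rebuilt architecture matches the checkpoint.
forecaster = type(forecaster).load_from_checkpoint(
    CHECKPOINT,
    d_model=100, d_ff=400, enc_layers=4, dec_layers=4, n_heads=4,
)
```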
```python
model = spacetimeformer_model.Spacetimeformer_Forecaster(d_x=6, d_y=6)
model.load_from_checkpoint(check_point)
data_module, inv_scaler, null_val = create_dset()
trainer = pl.Trainer()
trainer.test(model=model, datamodule=data_module)
```
This raises `RuntimeError: Error(s) in loading state_dict for Spacetimeformer_Forecaster: Unexpected key(s) in state_dict: ...`.
Could you tell me how to solve it? Thank you very much!
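One way to sidestep the hyperparameter bookkeeping is to rebuild the model with the exact constructor arguments used during training and copy the weights in manually. A minimal sketch, assuming a standard Lightning checkpoint layout and that the constructor defaults match your training flags for everything not passed explicitly:

```python
import torch
from spacetimeformer import spacetimeformer_model  # import path may vary

# Rebuild the architecture with the same values used during training.
model = spacetimeformer_model.Spacetimeformer_Forecaster(d_x=6, d_y=6)

# Lightning checkpoints store the weights under the "state_dict" key.
ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")  # hypothetical path
model.load_state_dict(ckpt["state_dict"])
model.eval()  # disable dropout etc. before running trainer.test
```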