Hello,

When I run the following command:
python train.py --test_only --checkpoint checkpoints/model.ckpt
the following error is raised:

Traceback (most recent call last):
  File "train.py", line 245, in <module>
    model = LightingModelWrapper.load_from_checkpoint(args.checkpoint, map_location='cpu', args=args)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 234, in _load_state
    obj = cls(**_cls_kwargs)
  File "train.py", line 29, in __init__
    self.model = model_cls(args,config)
  File "/home/wangyk/projects/BAM/models/bam.py", line 17, in __init__
    self.ssl_layer = getattr(hub, config.ssl_name)(ckpt=config.ssl_ckpt, fairseq=True)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/s3prl/upstream/wavlm/hubconf.py", line 26, in wavlm_local
    assert os.path.isfile(ckpt)
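The failing line is s3prl's file-existence check on the ckpt argument, which in this codebase comes from config.ssl_ckpt (models/bam.py, line 17). A minimal sketch of the check I believe is failing, assuming the path is meant to be supplied through the config rather than hard-coded in hubconf.py; the path below is a hypothetical placeholder:

import os

# Hypothetical local path to the pretrained WavLM weights; substitute your own.
ssl_ckpt = "/path/to/WavLM-Large.pt"

# s3prl's wavlm_local entry point asserts os.path.isfile(ckpt) before loading,
# so whatever config.ssl_ckpt resolves to must satisfy this check:
assert os.path.isfile(ssl_ckpt), f"SSL checkpoint not found: {ssl_ckpt}"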
After I changed the ckpt inside upstream/wavlm/hubconf.py to point to a local file and re-ran the command above, the following error is raised instead:
Traceback (most recent call last):
  File "train.py", line 245, in <module>
    model = LightingModelWrapper.load_from_checkpoint(args.checkpoint, map_location='cpu', args=args)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 247, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LightingModelWrapper:
    size mismatch for model.out_layer.weight: copying a param with shape torch.Size([2, 2048]) from checkpoint, the shape in current model is torch.Size([2, 3072]).
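For what it's worth, the shape saved in the checkpoint can be confirmed directly. A minimal sketch, assuming only the state_dict key named in the traceback:

import torch

# Load the Lightning checkpoint on CPU and inspect the saved output layer.
ckpt = torch.load("checkpoints/model.ckpt", map_location="cpu")
weight = ckpt["state_dict"]["model.out_layer.weight"]
print(weight.shape)  # torch.Size([2, 2048]) according to the traceback

This would show the released checkpoint was trained with a 2048-dimensional input to out_layer, while my current config builds a 3072-dimensional one, so I suspect my config differs from the one used for training. Which configuration should be used with this checkpoint?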