media-sec-lab / BAM

The PyTorch implementation of BAM for PartialSpoof audio localization.

assert os.path.isfile(ckpt) #3

Closed kingback2019 closed 3 months ago

kingback2019 commented 3 months ago

Hello, when I run the following command:

```
python train.py --test_only --checkpoint checkpoints/model.ckpt
```

it raises the following error:

```
Traceback (most recent call last):
  File "train.py", line 245, in <module>
    model = LightingModelWrapper.load_from_checkpoint(args.checkpoint, map_location='cpu', args=args)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 234, in _load_state
    obj = cls(**_cls_kwargs)
  File "train.py", line 29, in __init__
    self.model = model_cls(args, config)
  File "/home/wangyk/projects/BAM/models/bam.py", line 17, in __init__
    self.ssl_layer = getattr(hub, config.ssl_name)(ckpt=config.ssl_ckpt, fairseq=True)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/s3prl/upstream/wavlm/hubconf.py", line 26, in wavlm_local
    assert os.path.isfile(ckpt)
```
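For context, this assertion in s3prl's `wavlm_local` entry point fires whenever the path passed as `ckpt` (here, `config.ssl_ckpt`) does not point to an existing file on disk, so the path can be sanity-checked before launching training. A minimal sketch; the checkpoint filename below is only a placeholder, not a path from this repo:

```python
import os
import tempfile

def check_ssl_ckpt(ckpt: str) -> bool:
    # Mirrors the failing check: s3prl's wavlm_local does
    # `assert os.path.isfile(ckpt)` before loading the SSL model.
    return os.path.isfile(ckpt)

# A non-existent path reproduces the cause of the AssertionError
# (placeholder filename, assumed not to exist):
print(check_ssl_ckpt("wavlm_large_placeholder.pt"))

# Any existing file passes the check:
with tempfile.NamedTemporaryFile() as f:
    print(check_ssl_ckpt(f.name))  # True
```

Running a check like this first turns the bare `AssertionError` into an actionable "file not found" message.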

After I tried assigning `ckpt` inside `upstream/wavlm/hubconf.py` to a local file and re-ran the command above, the following error was raised instead:

```
Traceback (most recent call last):
  File "train.py", line 245, in <module>
    model = LightingModelWrapper.load_from_checkpoint(args.checkpoint, map_location='cpu', args=args)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/lightning/pytorch/core/saving.py", line 247, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/home/wangyk/anaconda3/envs/bam/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LightingModelWrapper:
	size mismatch for model.out_layer.weight: copying a param with shape torch.Size([2, 2048]) from checkpoint, the shape in current model is torch.Size([2, 3072]).
```

ZhongJiafeng-16 commented 3 months ago

Sorry, due to my oversight the model configuration contained an incorrect parameter. Please change the code at line 37 of `models/bam.py` to:

```python
self.out_layer = nn.Linear(in_features=2*config.embed_dim, out_features=2)
```
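The numbers in the error line up with this fix: the checkpoint's `out_layer.weight` has shape `(2, 2048)` while the current model builds `(2, 3072)`, consistent with `in_features` having been `3*config.embed_dim` instead of `2*config.embed_dim` (with `embed_dim = 1024`, an inference from the two shapes, not a value stated in this thread). A plain-Python sketch of the exact-shape check that strict `load_state_dict` performs:

```python
# Shapes taken from the error message above.
checkpoint_shape = (2, 2048)   # saved out_layer.weight
buggy_model_shape = (2, 3072)  # out_layer built with the wrong in_features

# embed_dim = 1024 is inferred from the shapes (3072 = 3*1024, 2048 = 2*1024).
embed_dim = 1024

# Strict state_dict loading requires exact shape equality per parameter:
print(checkpoint_shape == buggy_model_shape)   # False -> "size mismatch" RuntimeError
print(checkpoint_shape == (2, 2 * embed_dim))  # True  -> fixed layer loads cleanly
```

So after applying the one-line change, the checkpoint and model shapes agree and `load_from_checkpoint` should succeed without retraining.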