[Open] lxyjyy opened this issue 6 months ago
指的是转换参数的代码吗？
@lyuwenyu 对的，是把 Paddle 模型转换为 PyTorch 模型。我这边通过 ONNX 转 PyTorch 模型时会出现很多错误，想请问作者您是用什么脚本转换的？
https://github.com/lyuwenyu/cvperception/blob/main/src/cvperception/zoo/rtdetr/conver_params.py
def _select_torch_keys(state, version):
    """Return the PyTorch state-dict keys that are filled from the Paddle checkpoint.

    Keys keep ``state``'s iteration order because the conversion pairs them
    with the Paddle entries purely by position.

    Args:
        state: A ``state_dict()`` mapping of parameter/buffer names to tensors.
        version: Model variant; only ``1`` and ``2`` are supported.

    Returns:
        List of key names to fill, in order.

    Raises:
        ValueError: If ``version`` is not 1 or 2.
    """
    if version not in (1, 2):
        raise ValueError(f'unsupported version: {version!r} (expected 1 or 2)')
    # Excluded so the key count lines up with the Paddle checkpoint
    # (presumably it has no BatchNorm step counters — confirm).
    keys = [k for k in state if 'num_batches_tracked' not in k]
    if version == 2:
        # v2-only buffers that are not taken from the Paddle checkpoint.
        ignore_keys = ('anchors', 'valid_mask', 'num_points_scale')
        keys = [k for k in keys if not any(x in k for x in ignore_keys)]
    return keys


def _convert_param(key, pp, tp):
    """Convert one Paddle tensor to the layout of the matching PyTorch tensor.

    Args:
        key: PyTorch state-dict key being filled.
        pp: The Paddle value, already converted to a ``torch.Tensor``.
        tp: The model's current tensor for ``key``; only its shape is used.

    Returns:
        A ``torch.Tensor`` whose shape equals ``tp.shape``.

    Raises:
        ValueError: If the converted tensor's shape differs from ``tp.shape``.
    """
    import torch

    if 'denoising_class_embed' in key:
        # Append one zero row; the PyTorch embedding is expected to have one
        # more row than the Paddle one (enforced by the shape check below).
        out = torch.cat([pp, pp.new_zeros(1, pp.shape[-1])], dim=0)
    elif len(tp.shape) == 2:
        # Every other 2-D tensor is transposed. NOTE(review): this assumes all
        # such tensors are Linear weights stored transposed in Paddle; a 2-D
        # non-Linear parameter (e.g. an embedding) would need no transpose.
        out = pp.T
    else:
        out = pp
    if out.shape != tp.shape:
        raise ValueError(
            f'{key}: converted shape {tuple(out.shape)} '
            f'does not match model shape {tuple(tp.shape)}'
        )
    return out


def main(args) -> None:
    """Convert a Paddle RT-DETR checkpoint into a PyTorch checkpoint.

    The model is built from ``args.config``. Its (filtered) state-dict keys are
    paired by position with the entries of the Paddle checkpoint at
    ``args.pdparams``, converted and loaded into the model. The result is saved
    to ``args.output_file`` as ``{'ema': {'module': state_dict}}``.

    Args:
        args: Namespace with ``config``, ``version`` (1 or 2), ``pdparams``
            and ``output_file`` attributes.

    Raises:
        ValueError: On an unsupported version, a parameter-count mismatch,
            or a shape mismatch for any converted tensor.
    """
    import torch
    from cvperception.core import YAMLConfig

    model = YAMLConfig(args.config).model
    state = model.state_dict()
    # Validate the version before paying for the Paddle import/load.
    keys = _select_torch_keys(state, args.version)

    import paddle

    p_state = paddle.load(args.pdparams)
    pkeys = list(p_state.keys())
    if len(keys) != len(pkeys):
        raise ValueError(
            f'parameter count mismatch: {len(keys)} PyTorch keys '
            f'vs {len(pkeys)} Paddle keys'
        )

    # Pairing is purely positional: this relies on both checkpoints
    # enumerating their parameters in the same module order.
    new_state = {}
    for k, pk in zip(keys, pkeys):
        pp = torch.tensor(p_state[pk].numpy())
        new_state[k] = _convert_param(k, pp, state[k])

    # strict=False: keys filtered out above are left at the model's own values
    # instead of being reported as missing.
    model.load_state_dict(new_state, strict=False)

    checkpoint = {'ema': {'module': model.state_dict()}}
    torch.save(checkpoint, args.output_file)
非常感谢，我试一下代码。
你好，请问可以提供一下 paddle2pytorch 的脚本文件吗？