Open snakecy opened 4 years ago
这个,目前还没试过,如果模型跟bert一样话,那应该能加载,不一样的话暂时就不行了。
这个,目前还没试过,如果模型跟bert一样话,那应该能加载,不一样的话暂时就不行了。
spanbert是torch模型,有没有什么方法可以对齐进行转化?在网上搜索了一个方案,参数对不齐。请帮忙看看,谢谢,附代码如下:
import os
import numpy as np
import tensorflow as tf
import torch
v_list=[]
base_dir = r'spanbert_hf_base'
torch_model = os.path.join(base_dir, 'pytorch_model.bin')
raw_ckpt_model = r'chinese_L-12_H-768_A-12\bert_model.ckpt'
new_ckpt_model = r'./bert_model.ckpt'
for k,v in torch.load(torch_model, map_location='cpu').items():
v_list.append(np.array(v))
def change(ckpt_path, new_ckpt_path):
index = 0
with tf.Session() as sess:
for var_name, _ in tf.contrib.framework.list_variables(ckpt_path):
print(var_name)
var = tf.contrib.framework.load_variable(ckpt_path, var_name)
var = tf.Variable(v_list[index])
index+=1
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.save(sess, new_ckpt_path)
# ckpt_path = './chinese_L-12_H-768_A-12/bert_model.ckpt'
# new_ckpt_path = './bert_model.ckpt'
change(raw_ckpt_model, new_ckpt_model)
需要确认一下模型结构的一致性才能转。我抽空试试吧。
提问时请尽可能提供如下信息:
基本信息
核心代码
输出信息
自我尝试
不管什么问题,请先尝试自行解决,“万般努力”之下仍然无法解决再来提问。此处请贴上你的努力过程。