bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0
5.37k stars 926 forks source link

是否支持spanbert #176

Open snakecy opened 4 years ago

snakecy commented 4 years ago

提问时请尽可能提供如下信息:

  1. 如题,是否支持spanbert?
  2. 关系抽取还有什么建议改进的方向?

基本信息

核心代码

# 请在此处贴上你的核心代码。
# 请尽量只保留关键部分,不要无脑贴全部代码。

输出信息

# 请在此处贴上你的调试输出

自我尝试

不管什么问题,请先尝试自行解决,“万般努力”之下仍然无法解决再来提问。此处请贴上你的努力过程。

bojone commented 4 years ago

这个,目前还没试过,如果模型跟bert一样话,那应该能加载,不一样的话暂时就不行了。

snakecy commented 4 years ago

这个,目前还没试过,如果模型跟bert一样话,那应该能加载,不一样的话暂时就不行了。

spanbert是torch模型,有没有什么方法可以对齐进行转化?在网上搜索了一个方案,参数对不齐。请帮忙看看,谢谢,附代码如下:


import os
import numpy as np
import tensorflow as tf
import torch
v_list=[]

base_dir = r'spanbert_hf_base'

torch_model = os.path.join(base_dir, 'pytorch_model.bin')

raw_ckpt_model = r'chinese_L-12_H-768_A-12\bert_model.ckpt'

new_ckpt_model = r'./bert_model.ckpt'

for k,v in torch.load(torch_model, map_location='cpu').items():
    v_list.append(np.array(v))

def change(ckpt_path, new_ckpt_path):
    index = 0
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(ckpt_path):
            print(var_name)
            var = tf.contrib.framework.load_variable(ckpt_path, var_name)
            var = tf.Variable(v_list[index])
            index+=1

        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess, new_ckpt_path)

# ckpt_path = './chinese_L-12_H-768_A-12/bert_model.ckpt'
# new_ckpt_path = './bert_model.ckpt'
change(raw_ckpt_model, new_ckpt_model)
bojone commented 4 years ago

需要确认一下模型结构的一致性才能转。我抽空试试吧。