bojone / P-tuning

A simple experiment with the P-tuning method on Chinese

Converting to a pb file: the PtuningEmbedding layer raises an error #4

Open MonkeyTB opened 3 years ago

MonkeyTB commented 3 years ago

https://github.com/bojone/P-tuning/blob/aec82943f21268a6c813877ac055631e57cb96c3/bert.py#L122

Su (苏神), after saving the P-tuning model in h5 format and then converting it to a pb file (following https://github.com/bojone/bert4keras/issues/194), I get the error `ValueError: Unknown layer: PtuningEmbedding`. Do you know why this happens? The code is as follows:

```python
import os
os.environ['TF_KERAS'] = '1'

import numpy as np
import pandas as pd
from bert4keras.backend import keras, K
from bert4keras.layers import Loss, Embedding
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model, BERT
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from bert4keras.layers import Lambda, Dense
from keras.models import load_model
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution

disable_eager_execution()

model = 'model/Bert_Ptuning.h5'
base = '/model/pb'
keras_model = load_model(model, compile=False)
# Note: the "1" at the end of the path is the version number; it must be present,
# otherwise TF Serving will report that it cannot find a servable model
keras_model.save(base + '/Bert_Ptuning/1', save_format='tf')
```

bojone commented 3 years ago

When calling load_model, pass custom_objects={'PtuningEmbedding': PtuningEmbedding}.
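For reference, a minimal sketch of applying this suggestion to the conversion script above. It assumes PtuningEmbedding can be imported from this repo's bert.py (the file linked at the top of the issue); adjust the import to wherever the class is defined in your setup:

```python
import os
os.environ['TF_KERAS'] = '1'  # same setting as the original script

from keras.models import load_model
from bert import PtuningEmbedding  # assumption: the custom layer defined in this repo's bert.py

# Register the custom layer so Keras can deserialize it from the h5 file
keras_model = load_model(
    'model/Bert_Ptuning.h5',
    compile=False,
    custom_objects={'PtuningEmbedding': PtuningEmbedding},
)

# Export as a SavedModel (pb); the trailing "1" is the version directory
# that TF Serving expects
keras_model.save('/model/pb/Bert_Ptuning/1', save_format='tf')
```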

MonkeyTB commented 3 years ago

That works now, thanks~