https://github.com/bojone/P-tuning/blob/aec82943f21268a6c813877ac055631e57cb96c3/bert.py#L122
苏神,再用P-tuning代码保存为h5格式后,转pb文件时(参考https://github.com/bojone/bert4keras/issues/194),会报错ValueError: Unknown layer: PtuningEmbedding,麻烦问下知道这是为啥吗?代码如下
'''
import os
os.environ['TF_KERAS'] = '1'
import numpy as np
import pandas as pd
from bert4keras.backend import keras,K
from bert4keras.layers import Loss, Embedding
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model, BERT
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from bert4keras.layers import Lambda, Dense
from keras.models import load_model
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
model = 'model/Bert_Ptuning.h5'
base = '/model/pb'
keras_model = load_model(model,compile=False)
keras_model.save(base + '/Bert_Ptuning/1',save_format='tf') # <====注意model path里面的1是代表版本号,必须有这个不然tf serving 会报找不到可以serve的model
'''
https://github.com/bojone/P-tuning/blob/aec82943f21268a6c813877ac055631e57cb96c3/bert.py#L122 苏神,再用P-tuning代码保存为h5格式后,转pb文件时(参考https://github.com/bojone/bert4keras/issues/194),会报错ValueError: Unknown layer: PtuningEmbedding,麻烦问下知道这是为啥吗?代码如下 ''' import os os.environ['TF_KERAS'] = '1' import numpy as np import pandas as pd from bert4keras.backend import keras,K from bert4keras.layers import Loss, Embedding from bert4keras.tokenizers import Tokenizer from bert4keras.models import build_transformer_model, BERT from bert4keras.optimizers import Adam from bert4keras.snippets import sequence_padding, DataGenerator from bert4keras.snippets import open from bert4keras.layers import Lambda, Dense from keras.models import load_model import tensorflow as tf from tensorflow.python.framework.ops import disable_eager_execution disable_eager_execution() model = 'model/Bert_Ptuning.h5' base = '/model/pb' keras_model = load_model(model,compile=False) keras_model.save(base + '/Bert_Ptuning/1',save_format='tf') # <====注意model path里面的1是代表版本号,必须有这个不然tf serving 会报找不到可以serve的model '''