Open jxyxiangyu opened 2 years ago
提问时请尽可能提供如下信息:
# 参考的是您本仓库举的例子task_iflytek_adversarial_training.py中的一段代码 # 加载预训练模型 bert = build_transformer_model( config_path=config_path, checkpoint_path=checkpoint_path, return_keras_model=False, ) output = Lambda(lambda x: x[:, 0])(bert.model.output) output = Dense( units=num_classes, activation='sigmoid', kernel_initializer=bert.initializer )(output) model = keras.models.Model(bert.model.input, output) # 预测部分 token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen) pred = model.predict([[token_ids], [segment_ids]]) 请问,bert4keras可以像pytorch一样支持input_mask/attention_mask, label_ids这些输入吗? 如何才能再输入端支持这样的输入呢?我拜读了您写的源码,奈何我能力有限,只看到了bert的输入是Input
补充:另外可以问下模型输入的Input-Segment和Input-Token可以自己自定义成其他名字吗?可以在外面再包一层接口吗?
attention_mask、Input-Segment和Input-Token改名,可以通过继承Transformer类/BERT类等来实现。
另外,bert4keras只负责构建bert(transformer),其余大部分需求是个人通过学习keras后实现的。
十分感谢
提问时请尽可能提供如下信息:
基本信息
核心代码