bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0
5.37k stars 929 forks source link

Bert模型输入端是否支持像input_mask,label_ids这些输入呢? #489

Open jxyxiangyu opened 2 years ago

jxyxiangyu commented 2 years ago

提问时请尽可能提供如下信息:

基本信息

核心代码


# 参考的是您本仓库举的例子task_iflytek_adversarial_training.py中的一段代码
# 加载预训练模型
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
)

output = Lambda(lambda x: x[:, 0])(bert.model.output)
output = Dense(
    units=num_classes,
    activation='sigmoid',
    kernel_initializer=bert.initializer
)(output)

model = keras.models.Model(bert.model.input, output)
# 预测部分
token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
pred = model.predict([[token_ids], [segment_ids]])

请问,bert4keras可以像pytorch一样支持input_mask/attention_mask, label_ids这些输入吗?
如何才能再输入端支持这样的输入呢?我拜读了您写的源码,奈何我能力有限,只看到了bert的输入是Input
jxyxiangyu commented 2 years ago

补充:另外可以问下模型输入的Input-Segment和Input-Token可以自己自定义成其他名字吗?可以在外面再包一层接口吗?

bojone commented 2 years ago

attention_mask、Input-Segment和Input-Token改名,可以通过继承Transformer类/BERT类等来实现。

另外,bert4keras只负责构建bert(transformer),其余大部分需求是个人通过学习keras后实现的。

jxyxiangyu commented 2 years ago

十分感谢