bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0

Add a supervised task to pretraining #305

Closed stevewyl closed 3 years ago

stevewyl commented 3 years ago

When asking a question, please try to provide the following information:

Basic information

Idea behind the change

I want to add a supervised classification task on top of the existing MLM task to improve the pretrained model's performance when fine-tuning on downstream tasks. The texts are short and have no semantic links between them, so each sample is a single document of length 128 and documents are not concatenated.

Core code

# Paste your core code here.
# Keep only the key parts; do not paste everything wholesale.

# Replace paragraph_process with the following function
def avs_process(self, texts, starts, ends, paddings):
    cate_text, attr_text, sku_text, label = texts
    cate_token_ids, cate_mask_ids = self.sentence_process(cate_text, False)
    attr_token_ids, attr_mask_ids = self.sentence_process(attr_text, False)
    sku_token_ids, sku_mask_ids = self.sentence_process(sku_text)  # apply masking only to sku_text

    token_ids = [starts[0]] + cate_token_ids + [self.token_sep_id] + attr_token_ids + [self.token_sep_id] + sku_token_ids
    mask_ids = [starts[1]] + cate_mask_ids + [0] + attr_mask_ids + [0] + sku_mask_ids
    segment_ids = [0] * (len(cate_token_ids) + len(attr_token_ids) + 3) + [1] * len(sku_token_ids)
    if len(token_ids) > self.sequence_length - 1:
        token_ids = token_ids[:-1]
        mask_ids = mask_ids[:-1]
        segment_ids = segment_ids[:-1]
    assert len(token_ids) == len(mask_ids) == len(segment_ids)

    instance = [token_ids, mask_ids, segment_ids]
    for item, end, pad in zip(instance, ends, paddings):
        item.append(end)
        item = self.padding(item, pad)
    if not self.only_mlm:
        instance.append(int(label))

    return [instance]
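For reference, the token/segment layout that avs_process assembles above can be sketched in plain Python with toy ids (the ids here are made up; 101/102 stand for BERT's [CLS]/[SEP]):

```python
# Toy sketch of the avs_process sequence layout (hypothetical ids):
cate, attr, sku = [11, 12], [21], [31, 32, 33]

# [CLS] cate [SEP] attr [SEP] sku  -- 3 special tokens in total
token_ids = [101] + cate + [102] + attr + [102] + sku
segment_ids = [0] * (len(cate) + len(attr) + 3) + [1] * len(sku)

assert token_ids == [101, 11, 12, 102, 21, 102, 31, 32, 33]
assert segment_ids == [0, 0, 0, 0, 0, 0, 1, 1, 1]
assert len(token_ids) == len(segment_ids)
```

The `+ 3` in segment_ids accounts for the start token and the two separators, which keeps the two id lists the same length before truncation and padding.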

# Modify segment_ids and add has_answer
# only_mlm means only the MLM task is enabled
def load_tfrecord(record_names, sequence_length, batch_size, only_mlm=False):
    """给原方法补上parse_function
    """
    def parse_function(serialized):
        features = {
            'token_ids': tf.io.FixedLenFeature([sequence_length], tf.int64),
            'mask_ids': tf.io.FixedLenFeature([sequence_length], tf.int64),
            'segment_ids': tf.io.FixedLenFeature([sequence_length], tf.int64)
        }
        if not only_mlm:
            features["has_answer"] = tf.io.FixedLenFeature([1], tf.int64)
        features = tf.io.parse_single_example(serialized, features)
        token_ids = features['token_ids']
        mask_ids = features['mask_ids']
        segment_ids = features['segment_ids']
        # segment_ids = K.zeros_like(token_ids, dtype='int64')
        is_masked = K.not_equal(mask_ids, 0)
        masked_token_ids = K.switch(is_masked, mask_ids - 1, token_ids)
        x = {
            'Input-Token': masked_token_ids,
            'Input-Segment': segment_ids,
            'token_ids': token_ids,
            'is_masked': K.cast(is_masked, K.floatx()),
        }
        y = {
            'mlm_loss': K.zeros([1]),
            'mlm_acc': K.zeros([1]),
        }
        if not only_mlm:
            x["has_answer"] = K.cast(tf.reshape(tf.one_hot(features["has_answer"], 2), [-1]), K.floatx())
            y["ans_loss"] = K.zeros([1])
            y["ans_acc"] = K.zeros([1])
        return x, y

    return TrainingDataset.load_tfrecord(
        record_names, batch_size, parse_function
    )

# Modify the model's inputs and outputs
def build_transformer_model_with_mlm(config_path, only_mlm=False):
    """带mlm的bert模型
    """
    bert = build_transformer_model(
        config_path, with_mlm=True, with_nsp=not only_mlm, return_keras_model=False
    )

    if only_mlm:
        proba = bert.model.output
    else:
        ans_preds = bert.model.output[0]  # use the NSP classification probabilities as the predictions for the binary supervised task
        proba = bert.model.output[1]

    # auxiliary inputs
    token_ids = Input(shape=(None,), dtype='int64', name='token_ids')  # target ids
    is_masked = Input(shape=(None,), dtype=K.floatx(), name='is_masked')  # mask flags
    model_inputs = bert.model.inputs + [token_ids, is_masked]
    if not only_mlm:
        has_answer = Input(shape=(2,), dtype=K.floatx(), name='has_answer')  # auxiliary task: does the text contain attribute information
        model_inputs.append(has_answer)

    def mlm_acc(inputs):
        """Masked-token accuracy, wrapped as a layer
        (unchanged from the original pretraining.py).
        """
        y_true, y_pred, mask = inputs
        y_true = K.cast(y_true, K.floatx())
        acc = keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
        acc = K.sum(acc * mask) / (K.sum(mask) + K.epsilon())
        return acc

    def mlm_loss(inputs):
        """MLM loss, wrapped as a layer
        (unchanged from the original pretraining.py).
        """
        y_true, y_pred, mask = inputs
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * mask) / (K.sum(mask) + K.epsilon())
        return loss

    def ans_loss(inputs):
        """计算监督任务loss的函数,需要封装为一个层
        """
        y_true, y_pred = inputs
        loss = K.binary_crossentropy(y_true, y_pred)
        loss = K.mean(loss, -1)
        return loss

    def ans_acc(inputs):
        """计算监督任务准确率的函数,需要封装为一个层
        """
        y_true, y_pred = inputs
        y_true = K.cast(y_true, K.floatx())
        acc = keras.metrics.binary_accuracy(y_true, y_pred)
        return acc

    mlm_loss = Lambda(mlm_loss, name='mlm_loss')([token_ids, proba, is_masked])
    mlm_acc = Lambda(mlm_acc, name='mlm_acc')([token_ids, proba, is_masked])
    model_outputs = [mlm_loss, mlm_acc]
    if not only_mlm:
        ans_loss = Lambda(ans_loss, name='ans_loss')([has_answer, ans_preds])
        ans_acc = Lambda(ans_acc, name='ans_acc')([has_answer, ans_preds])
        model_outputs.append(ans_loss)
        model_outputs.append(ans_acc)

    train_model = Model(model_inputs, model_outputs)

    loss = {
        'mlm_loss': lambda y_true, y_pred: y_pred,
        'mlm_acc': lambda y_true, y_pred: K.stop_gradient(y_pred),
    }
    if not only_mlm:
        loss['ans_loss'] = lambda y_true, y_pred: y_pred
        loss['ans_acc'] = lambda y_true, y_pred: K.stop_gradient(y_pred)

    return bert, train_model, loss
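The loss dict above uses identity functions because each loss and metric is already computed inside the graph by a Lambda layer; Keras only has to pass the layer outputs through (with stop_gradient on the pure metrics so they contribute no gradients). A minimal plain-Python sketch of that pattern, with made-up numbers standing in for the Lambda layers' outputs:

```python
# Every named model output already *is* its loss value, so the Keras-level
# "loss function" is the identity on y_pred.
loss = {
    'mlm_loss': lambda y_true, y_pred: y_pred,
    'mlm_acc': lambda y_true, y_pred: y_pred,  # stand-in for the stop_gradient variant
}

# Dummy values standing in for what the in-graph Lambda layers would emit:
layer_outputs = {'mlm_loss': 9.53, 'mlm_acc': 0.45}
total = sum(fn(None, layer_outputs[name]) for name, fn in loss.items())
assert abs(total - 9.98) < 1e-9
```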

Output

# Paste your debug output here.
# The following TFRecord parsing error occurs whether or not only_mlm is enabled.

2021-03-13 13:39:16.197852: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Key: mask_ids.  Can't parse serialized Example.
Traceback (most recent call last):
  File "run_nsp.py", line 230, in <module>
    callbacks=[checkpoint, csv_logger]
  File "/home/stevewyl/anaconda3/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 727, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/stevewyl/anaconda3/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 675, in fit
    steps_name='steps_per_epoch')
  File "/home/stevewyl/anaconda3/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 300, in model_iteration
    batch_outs = f(actual_inputs)
  File "/home/stevewyl/anaconda3/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3476, in __call__
    run_metadata=self.run_metadata)
  File "/home/stevewyl/anaconda3/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1472, in __call__
    run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: {{function_node __inference_Dataset_map_parse_function_47}} Key: mask_ids.  Can't parse serialized Example.
         [[{{node ParseSingleExample/ParseSingleExample}}]]
         [[IteratorGetNext]]
  (1) Invalid argument: {{function_node __inference_Dataset_map_parse_function_47}} Key: mask_ids.  Can't parse serialized Example.
         [[{{node ParseSingleExample/ParseSingleExample}}]]
         [[IteratorGetNext]]
         [[IteratorGetNext/_2067]]

What I have tried

Whatever the problem, please try to solve it yourself first, and ask only if it still cannot be solved after every possible effort. Paste your attempts here.

  1. Processed an open-source corpus with data_utils.py and ran MLM training with pretraining.py: runs normally
  2. Processed an open-source corpus with data_utils.py and ran MLM training with pretraining_new.py + the only_mlm option: runs normally
  3. Processed the task corpus with data_utils_new.py (i.e. my modified version) and trained with pretraining_new.py: the error above occurs whether or not only_mlm is enabled
  4. Loading the TFRecords dataset on its own raises no exception, and model.summary() prints normally -> the problem should be in the data-parsing stage

The data prints correctly

for example in tf.io.tf_record_iterator(data_fn):
    print(tf.train.Example.FromString(example))

features {
  feature {
    key: "has_answer"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "mask_ids"
    value {
      int64_list {
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 104
        value: 104
        value: 104
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 104
        value: 104
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
      }
    }
  }
  feature {
    key: "segment_ids"
    value {
      int64_list {
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 0
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
        value: 1
      }
    }
  }
  feature {
    key: "token_ids"
    value {
      int64_list {
        value: 101
        value: 2339
        value: 3302
        value: 102
        value: 1331
        value: 2428
        value: 102
        value: 4294
        value: 3821
        value: 3294
        value: 1103
        value: 7226
        value: 6132
        value: 2137
        value: 1169
        value: 138
        value: 163
        value: 11316
        value: 8303
        value: 8148
        value: 140
        value: 676
        value: 1394
        value: 671
        value: 2339
        value: 868
        value: 3302
        value: 2137
        value: 1169
        value: 1313
        value: 8529
        value: 138
        value: 163
        value: 11316
        value: 8303
        value: 8148
        value: 140
        value: 1310
        value: 6132
        value: 1103
        value: 7226
        value: 6132
        value: 4511
        value: 1957
        value: 1217
        value: 1331
        value: 4633
        value: 2255
        value: 3302
        value: 1912
        value: 1947
        value: 2339
        value: 6163
        value: 2137
        value: 976
        value: 8342
        value: 6132
        value: 3302
        value: 150
        value: 8177
        value: 8978
        value: 8393
        value: 8144
        value: 102
      }
    }
  }
}

@bojone do you have any ideas for debugging this? Thanks!

stevewyl commented 3 years ago

The padding was missing in the data-generation stage; it runs normally now.
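For anyone hitting the same error: the culprit is the loop at the end of avs_process above. `item = self.padding(item, pad)` rebinds the loop variable instead of writing the padded list back into `instance`, so the serialized features keep their variable lengths and `FixedLenFeature` parsing fails. A minimal plain-Python reproduction (the `padding` function here is a hypothetical stand-in for the class method):

```python
def padding(seq, pad_value=0, length=8):
    # Hypothetical stand-in for the dataset's padding method: right-pad to `length`.
    return seq + [pad_value] * (length - len(seq))

instance = [[1, 2, 3], [0, 1, 0]]

# Buggy version: `item` is rebound, `instance` is left unpadded.
for item, pad in zip(instance, [0, 0]):
    item = padding(item, pad)
assert len(instance[0]) == 3  # still 3, not 8

# Fixed version: write the padded list back into the container.
for i, pad in enumerate([0, 0]):
    instance[i] = padding(instance[i], pad)
assert len(instance[0]) == 8
```

This also matches the Example dump above: the features are only 64 values long even though sequence_length is 128.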

stevewyl commented 3 years ago

I would also like to ask about pretraining. The model now trains normally, but the loss behaves very strangely. Each epoch is 250 steps; after about 5-6 epochs, mlm_acc starts to decline steadily.

epoch,loss,mlm_acc_loss,mlm_loss_loss
0,9.975207859039307,0.44808972,9.527121
1,9.961372146606445,0.5752151,9.386159
2,9.960256561279296,0.5910834,9.369174
3,9.959727100372314,0.6022016,9.357516
4,9.95986014175415,0.60329336,9.356565
5,9.959882137298584,0.5863511,9.373529
6,9.9597060546875,0.58159727,9.378114
7,9.959792503356933,0.5641684,9.395623

The pretraining corpus consists of product titles with non-natural semantics, such as "罗技(Logitech)适用于Mac的MX Master 3无线蓝牙优联双模跨计算机控制鼠标-深空灰". Is this kind of text unsuitable for pretraining, or is it simply too different from the corpus the loaded pretrained model was trained on, so that I should train a product-title language model from scratch?

bojone commented 3 years ago

I would also like to ask about pretraining. The model now trains normally, but the loss behaves very strangely. Each epoch is 250 steps; after about 5-6 epochs, mlm_acc starts to decline steadily.

epoch,loss,mlm_acc_loss,mlm_loss_loss
0,9.975207859039307,0.44808972,9.527121
1,9.961372146606445,0.5752151,9.386159
2,9.960256561279296,0.5910834,9.369174
3,9.959727100372314,0.6022016,9.357516
4,9.95986014175415,0.60329336,9.356565
5,9.959882137298584,0.5863511,9.373529
6,9.9597060546875,0.58159727,9.378114
7,9.959792503356933,0.5641684,9.395623

The pretraining corpus consists of product titles with non-natural semantics, such as "罗技(Logitech)适用于Mac的MX Master 3无线蓝牙优联双模跨计算机控制鼠标-深空灰". Is this kind of text unsuitable for pretraining, or is it simply too different from the corpus the loaded pretrained model was trained on, so that I should train a product-title language model from scratch?

Your mlm_acc is already at 60% yet mlm_loss is still this high; that does not look right. Maybe something is wired up wrong...
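A quick sanity check supports that suspicion (assuming the standard Chinese BERT vocabulary of 21128 tokens, which is an assumption about the checkpoint in use): a model predicting a uniform distribution over the vocabulary has a cross-entropy of ln(21128) ≈ 9.958, which is almost exactly where both training logs plateau.

```python
import math

vocab_size = 21128  # assumed: vocab size of the standard Chinese BERT checkpoint
uniform_ce = math.log(vocab_size)
assert abs(uniform_ce - 9.958) < 1e-3

# An mlm_loss pinned at ~9.958 means the predictions are essentially uniform,
# which cannot coexist with a genuine 60% masked-token accuracy -- so either
# the loss or the metric is being computed on the wrong tensors.
```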

stevewyl commented 3 years ago

I used the bert4keras/pretraining.py script to run MLM training on my product-title data, adjusting only batch_size (to 64) and grad_accum_steps (to 1; other values raise an error), and the same thing still happens:

epoch,loss,mlm_acc_loss,mlm_loss_loss
0,9.975546821594238,0.4527319,9.522816
1,9.961218914031983,0.57211065,9.389108
2,9.96013719177246,0.58912295,9.371013
3,9.959897575378418,0.60324585,9.356652
4,9.959748355865479,0.59947634,9.360272
5,9.959580730438232,0.5856575,9.373923
6,9.959896369934082,0.57579464,9.3841
7,9.959613983154297,0.558508,9.401109
8,9.959290103912354,0.29001126,9.669282
9,9.958436988830567,0.04440407,9.914034
10,9.95843703842163,0.042420294,9.916016
11,9.958437007904053,0.04402391,9.91441
12,9.958448265075683,0.042767487,9.91568
13,9.958436534881592,0.042882875,9.915551
14,9.958435176849365,0.044536836,9.913898
15,9.958435420989991,0.043750964,9.914683

It is probably still down to the corpus mismatch.

stevewyl commented 3 years ago

@bojone After adjusting the learning rate, the loss and acc (holding steady around 0.58) are now fairly stable. May I ask roughly how much hardware you used for pretraining? In the code the learning rate is set to 1e-3 and batch_size to 4096.

rxc205 commented 3 years ago


Hi, have you evaluated the model produced by this pretraining method? How well does it work? Would you be open to discussing it? Thanks!