bojone / bert4keras

Keras implementation of Transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0

recompute_grad does not work with estimator #368

Open Atakey opened 3 years ago

Atakey commented 3 years ago

When asking a question, please provide as much of the following information as possible:

Basic information

Core code

import os
os.environ['TF_KERAS'] = '1'
os.environ['RECOMPUTE'] = '1'
import tempfile
import numpy as np
from absl import app
from absl import flags
import tensorflow as tf

if tf.__version__ >= '2':
    # TF2: enable GPU memory growth so TF does not grab all GPU memory up front
    if tf.__version__ < '2.1':
        physical_devices = tf.config.experimental.list_physical_devices('GPU')
    else:
        physical_devices = tf.config.list_physical_devices('GPU')
    tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
else:
    from tensorflow.keras import backend as K

    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    K.set_session(sess)
from bert4keras.snippets import DataGenerator, sequence_padding
from bert4keras.models import build_transformer_model, Dense, Model

FLAGS = flags.FLAGS

base_dir = r'xxxxxxxxx\chinese_L-12_H-768_A-12'

flags.DEFINE_string("bert_config_file", os.path.join(base_dir, 'bert_config.json'), '')
flags.DEFINE_string("vocab_file", os.path.join(base_dir, 'vocab.txt'), '')
flags.DEFINE_string("init_checkpoint", os.path.join(base_dir, 'bert_model.ckpt'), '')
flags.DEFINE_integer("train_batch_size", 32, '')

class data_generator(DataGenerator):
    """
    Data generator for the Keras model
    """

    def __init__(self, data, batch_size=None):
        super().__init__(data, batch_size=batch_size)

    def __iter__(self, random=True):

        batch_token_ids, batch_token_segment_ids, batch_labels = [], [], []
        while 1:
            size = 32
            # dummy fixed-length sample: constant token ids, random binary label
            token_ids = [1234] * size
            label = [np.random.randint(2)]
            batch_token_ids.append(token_ids)
            token_segment_ids = [0] * size
            batch_token_segment_ids.append(token_segment_ids)
            batch_labels.append(label)

            if len(batch_token_ids) == self.batch_size:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_token_segment_ids = sequence_padding(batch_token_segment_ids)
                batch_labels = np.array(batch_labels)
                yield [batch_token_ids, batch_token_segment_ids], batch_labels
                batch_token_ids, batch_token_segment_ids, batch_labels = [], [], []

class estimator_data_generator(DataGenerator):
    """
    Data generator for the estimator
    """

    def __init__(self, data, batch_size=None):
        super().__init__(data, batch_size=batch_size)

    def __iter__(self, random=True):
        while 1:
            # dummy variable-length sample: constant token ids, random binary label
            size = np.random.randint(1, 33)
            token_ids = [1234] * size
            label = [np.random.randint(2)]
            token_segment_ids = [0] * size
            yield token_ids, token_segment_ids, label

def input_fn_builder():
    data_gen = estimator_data_generator([], batch_size=FLAGS.train_batch_size)

    def _decode_record(token, segment, label):
        return ({"Input-Token": token,
                 'Input-Segment': segment}, {'label': label})

    def input_fn():
        d = tf.data.Dataset.from_generator(data_gen.forfit,
                                           output_types=(tf.int64, tf.int64, tf.int64),
                                           output_shapes=(tf.TensorShape([None]), tf.TensorShape([None]),
                                                          tf.TensorShape([1])),
                                           )

        d = d.map(lambda token, segment, label: _decode_record(token, segment, label))
        d = d.padded_batch(FLAGS.train_batch_size,
                           padded_shapes=({'Input-Token': [None], 'Input-Segment': [None]},
                                          {'label': [None]}))
        return d

    return input_fn

def main(_):
    model = build_transformer_model(
        config_path=FLAGS.bert_config_file,
        checkpoint_path=FLAGS.init_checkpoint,
        model='bert',
        return_keras_model=True,
        with_pool=True)

    output = Dense(1, activation='sigmoid', name='label')(model.outputs[0])
    model = Model(model.inputs, output)
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['acc'])
    # Non-estimator mode: across tf 2.1, 2.2 and 2.3, recompute_grad mostly takes effect
    # data_gen = data_generator([], batch_size=FLAGS.train_batch_size)
    # model.fit_generator(data_gen.forfit(),
    #                     steps_per_epoch=40,
    #                     epochs=3, )

    # Estimator mode: recompute_grad has no effect
    with tempfile.TemporaryDirectory() as model_dir:
        estimator = tf.keras.estimator.model_to_estimator(keras_model=model,
                                                          model_dir=model_dir,
                                                          keras_model_path=None,
                                                          custom_objects=None,
                                                          config=None,
                                                          checkpoint_format='saver')
        tensors_to_log = {"outputs": 'label/Sigmoid:0'}
        logging_hook = tf.estimator.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=40)
        estimator.train(input_fn=input_fn_builder(), steps=400, hooks=[logging_hook])

if __name__ == "__main__":
    app.run(main)

What I tried

Recompute test results (keras==2.3.1):

- tf2.1 + keras: recompute works
- tf2.1 + tf.keras: recompute works
- tf2.2 + keras: recompute works
- tf2.2 + tf.keras: recompute works
- tf2.3 + keras: AttributeError: module 'tensorflow.python.framework.ops' has no attribute '_TensorLike' (most likely keras 2.3.1 is too old to be compatible with tf2.3)
- tf2.3 + tf.keras: recompute works

Whichever of tf 2.1, 2.2 or 2.3 is used, however, recomputation stops working once training goes through model_to_estimator.
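
One way to verify this beyond training speed is to compare peak GPU memory between the two settings. A minimal sketch, assuming TF >= 2.5 (where tf.config.experimental.get_memory_info and reset_memory_stats are available; the tf 2.1-2.3 versions in this issue predate them), with train_some_steps as a hypothetical stand-in for model.fit or estimator.train:

import tensorflow as tf

def peak_memory_mib(device='GPU:0'):
    # get_memory_info returns {'current': ..., 'peak': ...} in bytes (TF >= 2.5)
    return tf.config.experimental.get_memory_info(device)['peak'] / 2**20

tf.config.experimental.reset_memory_stats('GPU:0')
train_some_steps()  # hypothetical: run a fixed number of training steps
print('peak GPU memory: %.0f MiB' % peak_memory_mib())

Running the identical script once with RECOMPUTE=0 and once with RECOMPUTE=1 (set before importing bert4keras), the model.fit path should show a clearly lower peak with recomputation on, while the estimator path reportedly shows no change.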

Atakey commented 3 years ago

With the estimator there is no noticeable difference in training speed whether os.environ['RECOMPUTE'] is set to 0 or 1. For comparison, with the same model code and batch_size, tf2.1 + keras with recompute_grad trains roughly 20% slower than without recomputation.
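
For reference, a minimal timing sketch for making that comparison explicit; run_training is a hypothetical stand-in for model.fit or estimator.train over a fixed number of steps:

import time

def ms_per_step(run_training, n_steps=400):
    # time a fixed number of training steps; run once with RECOMPUTE=0
    # and once with RECOMPUTE=1, then compare the two numbers
    start = time.perf_counter()
    run_training(n_steps)
    return (time.perf_counter() - start) / n_steps * 1000

If recomputation were active, the RECOMPUTE=1 run should be measurably slower (about 20% in the tf2.1 + keras case above); identical timings suggest the recompute path is never taken.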

bojone commented 3 years ago

On the various recompute issues: proposed solutions are welcome, but if it is purely a problem report, there is nothing I can do. The recompute code was pieced together from other sources; I don't understand the underlying principle myself, so I have no way to improve it, let alone figure out how it should be used with something like estimator, which I have never used either.
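
For context, what RECOMPUTE=1 enables is gradient checkpointing: intermediate activations are dropped in the forward pass and recomputed during backprop, trading extra compute for lower memory. A minimal illustration of the general technique using TensorFlow's own tf.recompute_grad (not bert4keras's actual implementation, just the idea, assuming a recent TF 2.x):

import tensorflow as tf

block = tf.keras.Sequential([
    tf.keras.layers.Dense(768, activation='relu'),
    tf.keras.layers.Dense(768, activation='relu'),
])

@tf.recompute_grad
def checkpointed_block(x):
    # activations produced here are not kept for backprop;
    # they are recomputed when gradients are requested
    return block(x)

x = tf.random.normal([8, 768])
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(checkpointed_block(x)))
grads = tape.gradient(loss, block.trainable_variables)

The memory saved is the forward activations inside the wrapped function; the price is an extra forward pass per backward pass, consistent with the ~20% slowdown observed above.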