bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0
5.37k stars 929 forks source link

EMA optimizer raises an error #225

Open yifan-chen-2020 opened 4 years ago

yifan-chen-2020 commented 4 years ago

When asking a question, please provide the following information where possible:

Basic information

Core code

def buildmodel():
    model = Model(model.input, output)
    model.summary()
    AdamEMA = bert4keras.optimizers.extend_with_exponential_moving_average_v2(Adam, name='AdamEMA')
    new_adam = AdamEMA(learning_rate=learning_rate)
    model.compile(
        loss=CRF.sparse_loss,
        optimizer=new_adam,
        metrics=[CRF.sparse_accuracy]
    )
    return model,CRF, new_adam

....

class Evaluator(keras.callbacks.Callback):
    def __init__(self,valid_data):
        self.best_val_f1 = 0
        self.valid_data = valid_data

    def on_epoch_end(self, epoch, logs=None):
        optimizer.apply_ema_weights()  # the error is raised here

....

model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        callbacks=[evaluator],
        verbose=2
    )

Output

AttributeError: 'AdamEMA' object has no attribute 'model_weights'

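What I tried

I tried the regular (non-v2) version. Training time then dropped from a minute and a half to 17 s, and a divide by zero error appeared. The model barely changes: the F1 score stays at 0, with precision 1 and recall 0. Other extensions, such as piecewise linear, work normally.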

bojone commented 4 years ago

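The v2 version cannot be used with keras; you can only use extend_with_exponential_moving_average. bert4keras detects this automatically.

I have not run into the divide by zero error on my side. When exactly does it appear? Which line raises it?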
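A minimal sketch of the suggestion above, assuming bert4keras is installed with a plain keras backend. The helper name `build_ema_optimizer` is hypothetical and not part of the library; only `Adam` and `extend_with_exponential_moving_average` are taken from `bert4keras.optimizers`.

```python
def build_ema_optimizer(learning_rate):
    """Build an Adam optimizer with EMA via the non-v2 extension.

    Hypothetical helper for illustration. The import is deferred so that
    this module can be loaded even where bert4keras is not installed.
    """
    from bert4keras.optimizers import Adam, extend_with_exponential_moving_average

    # Use the non-v2 extension, as recommended above for keras.
    AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
    # Learning rate passed positionally, as the first optimizer argument.
    return AdamEMA(learning_rate)
```

The returned object would then be passed to `model.compile(optimizer=...)` in place of the v2 optimizer from the original post.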

yifan-chen-2020 commented 4 years ago

After looking into it, I found that the cause is self.iterations = 0

The error is raised here:

            self.old_weights = K.batch_get_value(self.model_weights)
            ema_weights = K.batch_get_value(self.ema_weights)

            if bias_correction:
                iterations = K.eval(self.iterations)  # this equals 0
                scale = 1.0 - np.power(self.ema_momentum, iterations)
                ema_weights = [weight / scale for weight in ema_weights]  # the error is raised here

bojone commented 4 years ago

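Did you run apply_ema_weights without having performed any update? In principle, as long as at least one update step has happened, it should not be zero.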
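The arithmetic behind this can be sketched in plain Python. This is an illustration only, not the library's code: it assumes the EMA starts at zero and is updated as `momentum * ema + (1 - momentum) * value`, and the names `ema_update` and `debias` are hypothetical. The scale formula `1 - momentum ** iterations` is the one from the snippet quoted earlier in the thread.

```python
def ema_update(ema, value, momentum):
    """One EMA step: keep `momentum` of the old average, blend in the new value."""
    return momentum * ema + (1.0 - momentum) * value


def debias(ema, momentum, iterations):
    """Divide the EMA by 1 - momentum ** iterations (bias correction)."""
    scale = 1.0 - momentum ** iterations
    if scale == 0.0:
        # iterations == 0: no update has happened yet, so the scale is zero
        # and the division in the quoted snippet would be a division by zero.
        raise ValueError("bias correction is undefined before the first update")
    return ema / scale


momentum = 0.9  # illustrative value only
ema = 0.0       # assumed zero initialisation
ema = ema_update(ema, 2.0, momentum)

# After one update the scale is 1 - momentum, and the corrected EMA
# recovers (approximately) the single value seen so far.
corrected = debias(ema, momentum, iterations=1)
```

With `iterations` equal to 0 the scale is exactly 0, which matches the reported symptom; after a single update it becomes `1 - momentum` and the division is well defined.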