bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0

Keras LayerNormalization after bert.model throw exception after save and load model #100

Closed: luoy2 closed this issue 4 years ago

luoy2 commented 4 years ago

Env:

python: 3.7.6
tensorflow: 2.1.0
bert4keras: 0.6.5
keras: 2.2.4-tf

Replicate

import os
os.environ["TF_KERAS"] = "1"
import tensorflow as tf
from tensorflow.keras.layers import *
from bert4keras.models import build_transformer_model

bert_config_path = 'pre_language_models/chinese_L-12_H-768_A-12/bert_config.json'
bert_checkpoint_path = 'pre_language_models/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = 'pre_language_models/chinese_L-12_H-768_A-12/vocab.txt'

bert = build_transformer_model(
    config_path=bert_config_path,
    checkpoint_path=bert_checkpoint_path,
    # model='albert',
    with_pool=True,
    return_keras_model=False,
)
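# Freeze every layer except the last one.
for i in bert.model.layers[:-1]:
    i.trainable = False

# Add the tf.keras LayerNormalization (which takes `axis`) on top of the BERT output.
bert_feature_norm = tf.keras.layers.LayerNormalization(axis=-1)(bert.output)
model = tf.keras.Model(bert.input, bert_feature_norm)

# Serialize the architecture to JSON and rebuild it; the rebuild raises the error below.
json_config = model.to_json()
reinitialized_model = tf.keras.models.model_from_json(json_config)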

Expected output

The model loads without error.

Actual output

Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-29-f3aca118058e>", line 1, in <module>
    reinitialized_model = tf.keras.models.model_from_json(json_config)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\saving\model_config.py", line 96, in model_from_json
    return deserialize(config, custom_objects=custom_objects)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\layers\serialization.py", line 106, in deserialize
    printable_module_name='layer')
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py", line 303, in deserialize_keras_object
    list(custom_objects.items())))
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\engine\network.py", line 937, in from_config
    config, custom_objects)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\engine\network.py", line 1893, in reconstruct_from_config
    process_layer(layer_data)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\engine\network.py", line 1875, in process_layer
    layer = deserialize_layer(layer_data, custom_objects=custom_objects)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\layers\serialization.py", line 106, in deserialize
    printable_module_name='layer')
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py", line 305, in deserialize_keras_object
    return cls.from_config(cls_config)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 519, in from_config
    return cls(**config)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\bert4keras\layers.py", line 217, in __init__
    super(LayerNormalization, self).__init__(**kwargs)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\bert4keras\layers.py", line 78, in __init__
    super(Layer, self).__init__(**kwargs)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\training\tracking\base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 186, in __init__
    generic_utils.validate_kwargs(kwargs, allowed_kwargs)
  File "C:\ProgramData\Anaconda3\envs\tf2gpu\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py", line 718, in validate_kwargs
    raise TypeError(error_message, kwarg)
TypeError: ('Keyword argument not understood:', 'axis')

Possible reason:

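Keras has its own implementation of the LayerNormalization layer, which accepts an axis argument, whereas the bert4keras layer of the same name does not. As a result, after the model is saved and reloaded, TensorFlow does not know which LayerNormalization implementation to instantiate, and the traceback shows the bert4keras class receiving the axis argument and rejecting it.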
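The suspected collision can be illustrated with a toy, pure-Python model of name-based deserialization. This is only a sketch: the registries, the `deserialize` helper, and both stand-in classes are hypothetical and are not the real Keras or bert4keras code. It assumes, as the traceback suggests, that a custom class registered under the name LayerNormalization is found before the built-in one.

```python
# Toy model of name-based layer deserialization (NOT the real Keras code).
# All names below are hypothetical stand-ins used only for illustration.

class KerasLayerNorm:
    """Stand-in for the built-in layer: accepts an `axis` argument."""
    def __init__(self, axis=-1):
        self.axis = axis


class CustomLayerNorm:
    """Stand-in for a third-party layer with the same name and no `axis`."""
    def __init__(self, **kwargs):
        for name in kwargs:
            # Mirrors the shape of the error seen in the traceback.
            raise TypeError('Keyword argument not understood:', name)


# Both classes are stored under the same serialized class name.
BUILTIN = {"LayerNormalization": KerasLayerNorm}
CUSTOM = {"LayerNormalization": CustomLayerNorm}


def deserialize(config, builtin=BUILTIN, custom=CUSTOM):
    """Look up the class by name (custom objects first) and rebuild it."""
    name = config["class_name"]
    cls = custom[name] if name in custom else builtin[name]
    return cls(**config["config"])


# The saved config only records the class *name*, not which module it came from.
saved = {"class_name": "LayerNormalization", "config": {"axis": -1}}

try:
    deserialize(saved)
    error = None
except TypeError as exc:
    error = exc  # the custom class shadowed the built-in one

# Serializing the built-in layer under a distinct name avoids the collision.
renamed = {"class_name": "KerasLayerNorm", "config": {"axis": -1}}
layer = deserialize(renamed, builtin={**BUILTIN, "KerasLayerNorm": KerasLayerNorm})
```

In the toy model, the first call fails because the lookup is by name only; giving the second layer a unique name lets both implementations coexist in one saved config.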

bojone commented 4 years ago

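This is not a bug; you are simply not using the library properly.

Mixing layers that share the same name is a very bad habit in itself.

More importantly, of course: choosing tf 2 is a rather poor choice to begin with.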