kpe / bert-for-tf2

A Keras TensorFlow 2.0 implementation of BERT, ALBERT and adapter-BERT.
https://github.com/kpe/bert-for-tf2
MIT License

Failed to get weights from pretrained google model #79

Closed. 121eddie closed this issue 3 years ago

121eddie commented 3 years ago

Windows 10 64-bit, Python 3.7.9, TensorFlow 2.3.1 (GPU), CUDA 10.1.243

I'm trying to load the weights of a pretrained Google BERT model:


import os, bert

model_name = "multi_cased_L-12_H-768_A-12"

model_dir = bert.fetch_google_bert_model(model_name, ".models")
model_ckpt = os.path.join(model_dir, "bert_model.ckpt")

bert_params = bert.params_from_pretrained_ckpt(model_dir)
l_bert = bert.BertModelLayer.from_params(bert_params, name="bert")

bert.load_bert_weights(l_bert, model_ckpt)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-4-c6bf0359e4b6> in <module>
     12 
---> 13 bert.load_bert_weights(l_bert, model_ckpt)

C:\ProgramData\Anaconda3\lib\site-packages\bert\loader.py in load_stock_weights(bert, ckpt_path, map_to_stock_fn)
    210     stock_weights = set(ckpt_reader.get_variable_to_dtype_map().keys())
--> 211     prefix = bert_prefix(bert)
    212 
    213     loaded_weights = set()

C:\ProgramData\Anaconda3\lib\site-packages\bert\loader.py in bert_prefix(bert)
    188     re_bert = re.compile(r'(.*)/(embeddings|encoder)/(.+):0')#here we get a list index out of range
--> 189     match = re_bert.match(bert.weights[0].name)
    190     assert match, "Unexpected bert layer: {} weight:{}".format(bert, bert.weights[0].name)
    191     prefix = match.group(1)

IndexError: list index out of range

yangxudong commented 3 years ago

I have the same problem. Have you solved it?

kpe commented 3 years ago

Yes, please note the comment in the examples from the README:

# use in Keras Model here, and call model.build()

For example, you could try:

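from tensorflow import keras

# use in Keras Model here, and call model.build()
model = keras.models.Sequential([
    keras.layers.InputLayer(input_shape=(128,)),
    l_bert,
    keras.layers.Lambda(lambda x: x[:, 0, :]),  # keep only the first token's output
    keras.layers.Dense(2)
])
model.build(input_shape=(None, 128))

# the layer's weights exist now, so the checkpoint can be loaded
bert.load_bert_weights(l_bert, model_ckpt)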

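Without model.build(), Keras does not instantiate any weights, so l_bert.weights is still an empty list and the checkpoint cannot be loaded. That empty list is what causes the IndexError at bert.weights[0] in the traceback above.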
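To see the lazy weight creation in isolation, here is a minimal sketch. It uses a plain keras.layers.Dense layer as a stand-in for BertModelLayer (an assumption made for brevity, the same Keras mechanism applies to any layer) and the variable names are only illustrative:

```python
# Illustration only: a plain Dense layer stands in for BertModelLayer.
from tensorflow import keras

layer = keras.layers.Dense(2)

# Before build(), Keras has not created any variables for the layer,
# so indexing layer.weights[0] here would raise IndexError.
weights_before = len(layer.weights)

# build() creates the variables (kernel and bias) for a 4-feature input.
layer.build((None, 4))
weights_after = len(layer.weights)

print(weights_before, weights_after)
```

The same ordering applies to the BERT layer: build first, load the checkpoint afterwards.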

Because this is such a common problem when using bert-for-tf2 for the first time, I'll add a descriptive assertion like:

AssertionError: BertModelLayer weights have not been instantiated yet. Please add the layer in a Keras model and call model.build() first!
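Roughly, such a check could look like the sketch below. This is not the actual code in bert/loader.py: checked_bert_prefix, _FakeWeight and _FakeLayer are hypothetical names used only for illustration, and the fake classes stand in for a real Keras layer so the snippet runs without TensorFlow. The regular expression is copied from the traceback above.

```python
import re

# Same pattern as shown in the traceback for bert_prefix().
RE_BERT = re.compile(r'(.*)/(embeddings|encoder)/(.+):0')


def checked_bert_prefix(layer):
    """Hypothetical variant of bert_prefix() that checks for weights first."""
    weights = list(layer.weights)
    if not weights:
        # Fail with a descriptive message instead of a bare IndexError.
        raise AssertionError(
            "BertModelLayer weights have not been instantiated yet. "
            "Please add the layer in a Keras model and call model.build() first!"
        )
    match = RE_BERT.match(weights[0].name)
    assert match, "Unexpected bert layer: {} weight:{}".format(layer, weights[0].name)
    return match.group(1)


class _FakeWeight:
    """Stand-in for a Keras variable; only the name attribute is needed."""
    def __init__(self, name):
        self.name = name


class _FakeLayer:
    """Stand-in for a layer; an empty list of names models an unbuilt layer."""
    def __init__(self, names):
        self.weights = [_FakeWeight(n) for n in names]


built = _FakeLayer(["bert/embeddings/word_embeddings/embeddings:0"])
unbuilt = _FakeLayer([])

prefix = checked_bert_prefix(built)

try:
    checked_bert_prefix(unbuilt)
    message = None
except AssertionError as err:
    message = str(err)
```

With a check like this, loading into an unbuilt layer reports what to do next rather than "list index out of range".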