google-research / adapter-bert


ValueError: Tensor not found in checkpoint #1

Closed · julia320 closed this issue 5 years ago

julia320 commented 5 years ago

Hi, I am trying to implement adapter modules in my copy of BERT; however, I am running into problems with adding the layers. As far as I can tell, the model gets built correctly, but when I try to run run_pretraining.py I get the following error: [screenshots of the stack trace, ending in "ValueError: ... Tensor not found in checkpoint"]

The problem is that it doesn't know what to do with the adapter layers, since their variables aren't found in the checkpoint file. How can I work around this, or get BERT to recognize that I want to add them in?
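To confirm the mismatch, here is a quick diagnostic sketch (assuming TF 1.x, and that the BERT-with-adapters graph has already been built): it compares the graph's trainable variables against the names stored in the checkpoint, so the adapter variables should show up as missing.

```python
import tensorflow as tf

# Read the checkpoint's variable names directly from its index file.
reader = tf.train.load_checkpoint(
    "/path/to/bert/multi_cased_L-12_H-768_A-12/bert_model.ckpt")
ckpt_names = set(reader.get_variable_to_shape_map())

# Assumes the model graph (BERT + adapters) has been built in this session.
for var in tf.trainable_variables():
    name = var.name.split(":")[0]  # drop the ":0" tensor suffix
    if name not in ckpt_names:
        print("not in checkpoint:", name)
```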

For reference, this is how I am running the script (I've modified it slightly to include the adapter modules as a flag):

python run_pretraining.py \
  --adapter=True \
  --input_file=/path/to/tfrecord/pretrained_iob2.tfrecord \
  --output_dir=/usr/bert/adapter \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=/path/to/bert/multi_cased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=/path/to/bert/multi_cased_L-12_H-768_A-12/bert_model.ckpt \
  --train_batch_size=16
julia320 commented 5 years ago

I figured it out:

I was using TensorFlow's WarmStartSettings to load the original BERT weights and biases and save time during training. However, I was using the default value for which variables to warm-start, like so:

warm_start = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=FLAGS.init_checkpoint)

The default value for vars_to_warm_start is '.*' (all of them), so it was trying to warm-start the variables from the layers I had added, even though they don't exist in the checkpoint. The solution was to filter tf.trainable_variables() down to a list of the names that actually appear in init_checkpoint and pass that into the warm-start settings as vars_to_warm_start. That way it only tries to load the variables that exist in the checkpoint and doesn't throw this error.
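A minimal sketch of that fix, assuming TF 1.x: instead of filtering tf.trainable_variables() (which requires the graph to be built first), it uses tf.train.list_variables to enumerate the checkpoint's names up front and passes them to WarmStartSettings as name patterns, so anything not in the checkpoint, such as the adapter layers, keeps its fresh initializer.

```python
import tensorflow as tf

init_checkpoint = "/path/to/bert/multi_cased_L-12_H-768_A-12/bert_model.ckpt"

# (name, shape) pairs read straight from the checkpoint's index file.
ckpt_var_names = [name for name, _ in tf.train.list_variables(init_checkpoint)]

# vars_to_warm_start also accepts a list of variable-name patterns; graph
# variables that match are warm-started, and the adapter variables (which
# have no counterpart in the checkpoint) are simply skipped.
warm_start = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=init_checkpoint,
    vars_to_warm_start=ckpt_var_names)
```

The resulting warm_start object is then handed to the Estimator's warm_start_from argument, the same as before.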