google-research / electra

ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
Apache License 2.0
2.31k stars, 351 forks

Error when pretraining on TPU: `Malformed device specification` #67

Closed: danyaljj closed this issue 4 years ago

danyaljj commented 4 years ago

I am trying to pretrain ELECTRA on my own data and I'm getting a `Malformed device specification` error. I'm wondering if you have any thoughts on what could be causing it. Here is the complete error log:

$ python3 run_pretraining.py --data-dir gs://danielk-files/fa-text/pretraining_data_electra/ --model-name electra-small
================================================================================
Config:
================================================================================
debug False
disallow_correct False
disc_weight 50.0
do_eval False
do_lower_case True
do_train True
electra_objective True
embedding_size 128
eval_batch_size 128
gcp_project ai2-tpu
gen_weight 1.0
generator_hidden_size 0.25
generator_layers 1.0
iterations_per_loop 200
keep_checkpoint_max 20
learning_rate 0.0005
lr_decay_power 1.0
mask_prob 0.15
max_predictions_per_seq 19
max_seq_length 128
model_dir gs://danielk-files/fa-text/pretraining_data_electra/models/electra-small
model_hparam_overrides {}
model_name electra-small
model_size small
num_eval_steps 100
num_tpu_cores 8
num_train_steps 1000000
pretrain_tfrecords gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord*
results_pkl gs://danielk-files/fa-text/pretraining_data_electra/models/electra-small/results/unsup_results.pkl
results_txt gs://danielk-files/fa-text/pretraining_data_electra/models/electra-small/results/unsup_results.txt
save_checkpoints_steps 1000
temperature 1.0
tpu_job_name my tpu job
tpu_name danielk-tpu-europe-west4-a-v3-8-1-new
tpu_zone europe-west4-a
train_batch_size 128
uniform_generator False
untied_generator True
untied_generator_embeddings False
use_tpu True
vocab_file gs://danielk-files/fa-text/pretraining_data_electra/vocab.txt
vocab_size 30522
weight_decay_rate 0.01

================================================================================
Running training
================================================================================
2020-06-07 23:00:13.709944: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:370] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.
2020-06-07 23:00:14.024300: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-06-07 23:00:14.024351: E tensorflow/stream_executor/cuda/cuda_driver.cc:318] failed call to cuInit: UNKNOWN ERROR (303)
2020-06-07 23:00:14.024375: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (danielk-tpu-vm-europe-west4-a): /proc/driver/nvidia/version does not exist
Model is built!
ERROR:tensorflow:Error recorded from training_loop: Malformed device specification '/job:my tpu job /task:0/device:CPU:0' in node: {name:'input_pipeline_task0/Const' id:19 op device:{} def:{node input_pipeline_task0/Const (defined at /home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1748)  = Const[_xla_inferred_shapes=[[143]], dtype=DT_STRING, value=Tensor<type: string shape: [143] values: gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-0-of-1000 gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-105-of-1000 gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-112-of-1000...>, _device="/job:my tpu job /task:0/device:CPU:0"]()}}
         [[input_pipeline_task0/Const]]
Traceback (most recent call last):
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
    return fn(*args)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1348, in _run_fn
    self._extend_graph()
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1388, in _extend_graph
    tf_session.ExtendSession(self._session)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Malformed device specification '/job:my tpu job /task:0/device:CPU:0' in node: {name:'input_pipeline_task0/Const' id:19 op device:{} def:{{{node input_pipeline_task0/Const}} = Const[_xla_inferred_shapes=[[143]], dtype=DT_STRING, value=Tensor<type: string shape: [143] values: gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-0-of-1000 gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-105-of-1000 gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-112-of-1000...>, _device="/job:my tpu job /task:0/device:CPU:0"]()}}
         [[input_pipeline_task0/Const]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "run_pretraining.py", line 385, in <module>
    main()
  File "run_pretraining.py", line 381, in main
    args.model_name, args.data_dir, **hparams))
  File "run_pretraining.py", line 344, in train_or_eval
    max_steps=config.num_train_steps)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3035, in train
    rendezvous.raise_errors()
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py", line 136, in raise_errors
    six.reraise(typ, value, traceback)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/six.py", line 703, in reraise
    raise value
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3030, in train
    saving_listeners=saving_listeners)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 370, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1161, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1195, in _train_model_default
    saving_listeners)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1490, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 584, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1014, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 725, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1207, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1212, in _create_session
    return self._sess_creator.create_session()
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 878, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 647, in create_session
    init_fn=self._scaffold.init_fn)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/training/session_manager.py", line 296, in prepare_session
    sess.run(init_op, feed_dict=init_feed_dict)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1180, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1359, in _do_run
    run_metadata)
  File "/home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Malformed device specification '/job:my tpu job /task:0/device:CPU:0' in node: {name:'input_pipeline_task0/Const' id:19 op device:{} def:{node input_pipeline_task0/Const (defined at /home/danielk/anaconda3/envs/py37_electra/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1748)  = Const[_xla_inferred_shapes=[[143]], dtype=DT_STRING, value=Tensor<type: string shape: [143] values: gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-0-of-1000 gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-105-of-1000 gs://danielk-files/fa-text/pretraining_data_electra/pretrain_tfrecords/pretrain_data.tfrecord-112-of-1000...>, _device="/job:my tpu job /task:0/device:CPU:0"]()}}
         [[input_pipeline_task0/Const]]
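The device string in the error is where things go wrong. A TensorFlow device spec has the shape `/job:<name>/replica:<n>/task:<n>/device:<type>:<n>`, and, as far as we can tell from TensorFlow's device-name parser, a job name must start with a letter and contain only letters, digits, and underscores. The job name here is `my tpu job ` (spaces plus a trailing space), so the whole spec is rejected. The sketch below is a simplified Python approximation of that check, not TensorFlow's actual parser: the regex is an assumption, `host_device` is a hypothetical helper that only mirrors the string shape seen in the log, and `tpu_worker` is just an illustrative valid name.

```python
import re

# Assumption: simplified version of the rule TensorFlow's device-name parser
# applies to job names (a letter, then letters, digits, or underscores).
JOB_NAME_RE = re.compile(r"[A-Za-z][A-Za-z0-9_]*")


def host_device(job_name, task_id):
    """Hypothetical helper: format a host CPU device string like the one in the log."""
    return "/job:%s/task:%d/device:CPU:0" % (job_name, task_id)


def job_of(device_spec):
    """Return the text between '/job:' and the next '/', or None if absent."""
    match = re.match(r"/job:([^/]*)", device_spec)
    return match.group(1) if match else None


def is_well_formed(device_spec):
    """Check only the job component of a device spec under the assumed rule."""
    job = job_of(device_spec)
    return job is None or JOB_NAME_RE.fullmatch(job) is not None


if __name__ == "__main__":
    bad = host_device("my tpu job ", 0)   # value from configure_pretraining.py
    good = host_device("tpu_worker", 0)   # hypothetical valid job name
    print(repr(bad), is_well_formed(bad))    # '/job:my tpu job /task:0/device:CPU:0' False
    print(repr(good), is_well_formed(good))  # '/job:tpu_worker/task:0/device:CPU:0' True
```

Note that `repr(bad)` reproduces the exact string from the log, trailing space included, which is easy to miss in the plain `tpu_job_name my tpu job` line of the config dump.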

Note that I have already updated configure_pretraining.py:

(py37_electra) danielk@danielk-tpu-vm-europe-west4-a:~/electra$ git diff
diff --git a/configure_pretraining.py b/configure_pretraining.py
index f576563..1e87b36 100644
--- a/configure_pretraining.py
+++ b/configure_pretraining.py
@@ -48,7 +48,7 @@ class PretrainingConfig(object):
     self.save_checkpoints_steps = 1000
     self.num_train_steps = 1000000
     self.num_eval_steps = 100
-    self.keep_checkpoint_max = 5 # maximum number of recent checkpoint files to keep;
+    self.keep_checkpoint_max = 20 # maximum number of recent checkpoint files to keep;
                                  # change to 0 or None to keep all checkpoints

     # model settings
@@ -81,12 +81,12 @@ class PretrainingConfig(object):
     self.eval_batch_size = 128

     # TPU settings
-    self.use_tpu = False
-    self.num_tpu_cores = 1
-    self.tpu_job_name = None
-    self.tpu_name = None  # cloud TPU to use for training
-    self.tpu_zone = None  # GCE zone where the Cloud TPU is located in
-    self.gcp_project = None  # project name for the Cloud TPU-enabled project
+    self.use_tpu = True
+    self.num_tpu_cores = 8
+    self.tpu_job_name = "my tpu job "
+    self.tpu_name = "danielk-tpu-europe-west4-a-v3-8-1-new"  # cloud TPU to use for training
+    self.tpu_zone = "europe-west4-a"  # GCE zone where the Cloud TPU is located in
+    self.gcp_project = "ai2-tpu"  # project name for the Cloud TPU-enabled project

     # default locations of data files
     self.pretrain_tfrecords = os.path.join(
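Instead of editing configure_pretraining.py, the same overrides can probably be passed on the command line. The traceback above shows `run_pretraining.py` building its config with `args.model_name, args.data_dir, **hparams`; assuming `hparams` comes from the script's `--hparams` JSON flag (as the ELECTRA README describes), the sketch below builds that command. Omitting `tpu_job_name` keeps its default of `None`, the value on the `-` lines of the diff. `pretraining_command` is a hypothetical helper, not part of ELECTRA.

```python
import json
import shlex

# TPU overrides from the diff above, minus tpu_job_name: leaving it out keeps
# the PretrainingConfig default (None).
tpu_overrides = {
    "use_tpu": True,
    "num_tpu_cores": 8,
    "tpu_name": "danielk-tpu-europe-west4-a-v3-8-1-new",
    "tpu_zone": "europe-west4-a",
    "gcp_project": "ai2-tpu",
    "keep_checkpoint_max": 20,
}


def pretraining_command(data_dir, model_name, overrides):
    """Hypothetical helper: build a shell-quoted run_pretraining.py command line
    that passes `overrides` as a JSON dict through --hparams."""
    args = [
        "python3", "run_pretraining.py",
        "--data-dir", data_dir,
        "--model-name", model_name,
        "--hparams", json.dumps(overrides, sort_keys=True),
    ]
    return " ".join(shlex.quote(arg) for arg in args)


if __name__ == "__main__":
    print(pretraining_command(
        "gs://danielk-files/fa-text/pretraining_data_electra/",
        "electra-small",
        tpu_overrides,
    ))
```

`shlex.quote` wraps the JSON in single quotes so the shell passes it through as one argument.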

For completeness, here is the TPU I use:

TPU type v3-8
TPU software version 1.15

@stefan-it @clarkkev, do you have any thoughts on this issue?

stefan-it commented 4 years ago

Hi @danyaljj, try setting `self.tpu_job_name` to `None`.

Btw: here's the configuration that I've used for training the Turkish ELECTRA model: https://github.com/stefan-it/turkish-bert/blob/master/electra/configure_pretraining_base.py#L83-L89

I hope this helps :)
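Setting `tpu_job_name` to `None` works because, as we read the `TPUConfig` documentation, the job name is normally inferred within TPUEstimator and only needs to be set explicitly in unusual cluster setups (ELECTRA appears to pass this setting straight through to `TPUConfig`). To surface a bad value like `"my tpu job "` at config time rather than deep inside session creation, a guard along these lines could be used. `checked_tpu_job_name` is a hypothetical helper, not part of ELECTRA, and its regex is an assumed approximation of TensorFlow's job-name rule.

```python
import re

# Assumption: approximation of TensorFlow's job-name rule (a letter, then
# letters, digits, or underscores); the real parser may differ in details.
_JOB_NAME_RE = re.compile(r"[A-Za-z][A-Za-z0-9_]*")


def checked_tpu_job_name(name):
    """Hypothetical guard for the tpu_job_name setting.

    None is returned unchanged so that TPUEstimator infers the job name.
    Any other value must satisfy the assumed job-name rule, otherwise a
    ValueError is raised before any graph or session is built.
    """
    if name is None:
        return None
    if _JOB_NAME_RE.fullmatch(name) is None:
        raise ValueError(
            "tpu_job_name %r is not a valid TensorFlow job name; "
            "set it to None to let TPUEstimator infer it" % (name,))
    return name
```

In the TPU settings block this would read, for example, `self.tpu_job_name = checked_tpu_job_name(None)`.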

danyaljj commented 4 years ago

Hi Stefan, thanks for the quick response! Looks like that resolved the issue. Appreciate the help!