bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0
5.37k stars, 929 forks

Cannot load the checkpoint when continuing pretraining with single-machine multi-GPU training #480

Closed. KY0coder closed this issue 2 years ago.

KY0coder commented 2 years ago

When running GAU pretraining with single-machine multi-GPU training, the checkpoint cannot be loaded.

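Code:

```python
import tensorflow as tf
from bert4keras.models import build_transformer_model

# config_path, checkpoint_path and the GAU_alpha model class are defined elsewhere in the script (not shown)
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:3"])
with strategy.scope():
    # Build the GAU-alpha model with an MLM head and load the checkpoint inside the strategy scope
    base = build_transformer_model(
        config_path,
        checkpoint_path=checkpoint_path,
        model=GAU_alpha,
        with_mlm='linear',
        return_keras_model=False
    )
    model = base.model  # the underlying Keras model
```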

Error:

```
Traceback (most recent call last):
  File "pretrain.py", line 224, in <module>
    dataset, steps_per_epoch=1000, epochs=epochs, callbacks=[evaluator]
  File "/home/amy/anaconda3/envs/zhangkaiyuan_tf_gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 649, in fit
    validation_freq=validation_freq)
  File "/home/amy/anaconda3/envs/zhangkaiyuan_tf_gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_distributed.py", line 143, in fit_distributed
    steps_name='steps_per_epoch')
  File "/home/amy/anaconda3/envs/zhangkaiyuan_tf_gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 274, in model_iteration
    batch_outs = f(actual_inputs)
  File "/home/amy/anaconda3/envs/zhangkaiyuan_tf_gpu/lib/python3.6/site-packages/tensorflow/python/keras/backend.py", line 3292, in __call__
    run_metadata=self.run_metadata)
  File "/home/amy/anaconda3/envs/zhangkaiyuan_tf_gpu/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1458, in __call__
    run_metadata_ptr)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable Transformer-1-GatedAttentionUnit/scale_offset_2/gamma from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/Transformer-1-GatedAttentionUnit/scale_offset_2/gamma/N10tensorflow3VarE does not exist.
	 [[{{node Transformer-1-GatedAttentionUnit_1/scale_offset_2/mul/ReadVariableOp}}]]
```

Training on a single GPU does not raise this error, and loading the weights with Keras's built-in `load_weights` also works. Su (bojone), is the problem in bert4keras's `load_weights_from_checkpoint`?

KY0coder commented 2 years ago

Ah, after reading through the `load_weights_from_checkpoint` code carefully, I've solved it.
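The thread does not say what the fix was. One ordering that avoids loading weights during model construction under `MirroredStrategy`, and which bert4keras's own multi-GPU example appears to follow, is: build the model inside the strategy scope with `checkpoint_path=None`, compile it, and only then call `load_weights_from_checkpoint`. The sketch below assumes that ordering; it is not necessarily KY0coder's fix. The function name `build_and_compile` and its `loss`/`optimizer` parameters are hypothetical, and `model_cls` stands in for the `GAU_alpha` class from the original script.

```python
import os


def build_and_compile(model_cls, config_path, checkpoint_path,
                      loss=None, optimizer="adam",
                      devices=("/gpu:0", "/gpu:3")):
    """Hypothetical helper: build under MirroredStrategy, compile, then load weights last."""
    # bert4keras reads TF_KERAS when it is first imported, so set it before the import.
    # Assumption: the tf.keras backend is used, as the traceback above suggests.
    os.environ.setdefault("TF_KERAS", "1")
    import tensorflow as tf
    from bert4keras.models import build_transformer_model

    strategy = tf.distribute.MirroredStrategy(devices=list(devices))
    with strategy.scope():
        base = build_transformer_model(
            config_path,
            checkpoint_path=None,  # do not load pretrained weights while building
            model=model_cls,
            with_mlm='linear',
            return_keras_model=False,
        )
        model = base.model
        model.compile(loss=loss, optimizer=optimizer)
        # Load the pretrained weights only after the model has been compiled.
        base.load_weights_from_checkpoint(checkpoint_path)
    return base, model
```

Usage would be something like `base, model = build_and_compile(GAU_alpha, config_path, checkpoint_path)` followed by the usual `model.fit(...)` call.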