Open akikaaa opened 3 years ago
We have fixed the while_loop error in the Easy Parallel Library (EPL): https://github.com/alibaba/EasyParallelLibrary
You can enable gradient checkpointing, with the checkpoint tensors taken from a collection, like this:
import epl
epl.init(epl.Config({"gradient_checkpoint.type": "collection"}))
epl.set_default_strategy(epl.replicate(1))
model_with_checkpoint()
You can also try the "auto" gradient checkpoint selection, where EPL automatically finds the entrance of each layer and uses those as checkpoint tensors:
import epl
epl.init(epl.Config({"gradient_checkpoint.type": "auto"}))
epl.set_default_strategy(epl.replicate(1))
model()
I'm trying to apply this awesome tool to a BERT model, but it doesn't seem to work with TF while loops. The model code is basically the same as https://github.com/CLUEbenchmark/CLUENER2020/blob/master/tf_version/modeling.py, except that I add the output of every sqrt(num_hidden_layers)-th hidden layer to the collection with tf.add_to_collection('checkpoints', layer_output). When I run training, I get this error message: "ValueError: Cannot use 'loss/rnn/while/TensorArrayReadV3/Enter' as input to 'loss/rnn/while/TensorArrayReadV3_1' because 'loss/rnn/while/TensorArrayReadV3/Enter' is in a whileloop. See info log for more details." Could you please help me solve this problem?
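For readers trying to reproduce the setup: the issue does not show how the "every sqrt(num_hidden_layers)" selection was coded, so the following is only a minimal sketch of one way it might be done. The helper name checkpoint_layer_indices and the choice to round the square root to the nearest integer are assumptions, not taken from the original code.

```python
import math


def checkpoint_layer_indices(num_hidden_layers):
    """Return 0-based indices of the layers whose outputs would be checkpointed.

    Hypothetical helper: uses a stride of round(sqrt(num_hidden_layers)) and
    selects the last layer of each group of `stride` consecutive layers.
    """
    if num_hidden_layers <= 0:
        return []
    # Assumption: round to the nearest integer, never going below a stride of 1.
    stride = max(1, int(round(math.sqrt(num_hidden_layers))))
    return [i for i in range(num_hidden_layers) if (i + 1) % stride == 0]


# Inside the transformer layer loop of modeling.py (TF 1.x), the selection
# could then be applied roughly like this (not executed here):
#
#   selected = set(checkpoint_layer_indices(num_hidden_layers))
#   for layer_idx in range(num_hidden_layers):
#       ...  # build the layer, producing layer_output
#       if layer_idx in selected:
#           tf.add_to_collection('checkpoints', layer_output)
```

With a 12-layer BERT this stride works out to 3, so four layer outputs end up in the 'checkpoints' collection.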