cybertronai / gradient-checkpointing

Make huge neural nets fit in memory
MIT License

TF while loop error #52

Open akikaaa opened 3 years ago

akikaaa commented 3 years ago

I'm trying to apply this awesome tool to a BERT model, but it doesn't seem to work with TF while loops. The model code is basically the same as https://github.com/CLUEbenchmark/CLUENER2020/blob/master/tf_version/modeling.py, except that I add every sqrt(num_hidden_layers)-th hidden layer output to the collection via tf.add_to_collection('checkpoints', layer_output). When I run training, I get this error message: "ValueError: Cannot use 'loss/rnn/while/TensorArrayReadV3/Enter' as input to 'loss/rnn/while/TensorArrayReadV3_1' because 'loss/rnn/while/TensorArrayReadV3/Enter' is in a while loop. See info log for more details." Could you please help me solve this problem?
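For context, the memory/recompute trade-off that checkpointing every sqrt(num_hidden_layers)-th activation is aiming for can be sketched in plain Python. This is a toy scalar chain of layers, not the library's TensorFlow API: the checkpointed backward pass stores only every k-th activation and recomputes the segment in between, yet produces the same gradient as full backprop.

```python
import math

def layer(x, a):
    # one "layer": y = tanh(a * x)
    return math.tanh(a * x)

def dlayer(x, a):
    # d/dx tanh(a * x) = a * (1 - tanh(a*x)^2)
    t = math.tanh(a * x)
    return a * (1.0 - t * t)

def grad_full(x0, coeffs):
    # Standard backprop: store every activation (memory O(n)).
    acts = [x0]
    for a in coeffs:
        acts.append(layer(acts[-1], a))
    g = 1.0
    for i in reversed(range(len(coeffs))):
        g *= dlayer(acts[i], coeffs[i])
    return g

def grad_checkpointed(x0, coeffs, k):
    # Gradient checkpointing: store only every k-th activation
    # (memory O(n/k + k)); recompute the rest during backward.
    n = len(coeffs)
    ckpts = {0: x0}
    x = x0
    for i, a in enumerate(coeffs, 1):
        x = layer(x, a)
        if i % k == 0:
            ckpts[i] = x
    g = 1.0
    for i in reversed(range(n)):
        start = (i // k) * k          # nearest stored checkpoint
        x = ckpts[start]
        for j in range(start, i):     # recompute segment forward
            x = layer(x, coeffs[j])
        g *= dlayer(x, coeffs[i])
    return g
```

With k near sqrt(n), both the stored checkpoints and the recomputed segment stay O(sqrt(n)), which is the usual motivation for that choice; the two functions return identical gradients.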

SeaOfOcean commented 2 years ago

We have fixed the while_loop error in the Easy Parallel Library (EPL): https://github.com/alibaba/EasyParallelLibrary

You can enable gradient checkpointing with:

import epl
epl.init(epl.Config({"gradient_checkpoint.type": "collection"}))
epl.set_default_strategy(epl.replicate(1))

model_with_checkpoint()

You can also try the "auto" checkpoint selection, where EPL automatically finds the entrance of each layer and uses it as a checkpoint tensor:

import epl
epl.init(epl.Config({"gradient_checkpoint.type": "auto"}))
epl.set_default_strategy(epl.replicate(1))

model()