rnajena / bertax_training

Training scripts for BERTax

Attempt to use more GPUs to train the model #9

Closed: LostSpirit1307 closed this issue 1 year ago

LostSpirit1307 commented 1 year ago

Is there any way to use 2-4 GPUs to train this model?

f-kretschmer commented 1 year ago

It should be possible; I will look into it, but it might take some time. In general, multi-GPU training in Keras/TensorFlow is handled with `mirrored_strategy = tensorflow.distribute.MirroredStrategy()`. If you want to test it yourself, the first step would probably be to add the line `with mirrored_strategy.scope():` before the call that gets/builds the model, and indent that call under it so the model is built inside the strategy's scope. For example, for pre-training the nc version of BERTax, this would be line 84 in `model/bert_nc.py`.
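A minimal sketch of that pattern, assuming TensorFlow 2.x with its bundled Keras API. `build_model` is a hypothetical stand-in for BERTax's model-building code (a tiny dense network, not BERT), and the random arrays stand in for real training data. The point it illustrates is that the model, and its compiled optimizer, are created inside `mirrored_strategy.scope()`, while `fit()` can be called afterwards as usual.

```python
import numpy as np
import tensorflow as tf


# Hypothetical stand-in for BERTax's model builder (the real code in
# model/bert_nc.py builds a BERT model instead of this toy network).
def build_model(input_dim: int = 16, num_classes: int = 4) -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(input_dim,))
    x = tf.keras.layers.Dense(32, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(inputs, outputs)
    # Compiling here means the optimizer's variables are also created in the scope.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


# Replicates the model on every visible GPU; falls back to the CPU if none are found.
mirrored_strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", mirrored_strategy.num_replicas_in_sync)

# Model weights must be created inside the strategy's scope.
with mirrored_strategy.scope():
    model = build_model()

# Each global batch is split across the replicas, so scale it with the replica count.
per_replica_batch = 8
global_batch = per_replica_batch * mirrored_strategy.num_replicas_in_sync

# Random placeholder data (hypothetical; replace with the real dataset).
x_train = np.random.rand(64, 16).astype("float32")
y_train = np.random.randint(0, 4, size=(64,))
history = model.fit(x_train, y_train, batch_size=global_batch, epochs=1, verbose=0)
```

To use only some of the GPUs, e.g. two out of four, `MirroredStrategy` accepts a device list such as `tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])`; alternatively, the visible GPUs can be limited with the `CUDA_VISIBLE_DEVICES` environment variable.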