Closed jackn11 closed 2 years ago
The whole model is trained, according to `const ps = params(bert_model)`. You can switch to `params(bert_model.classifier)` if you only want to train the classifier. The difference between `_bert_model` and `bert_model` is that `_bert_model` is on the CPU while `bert_model` is on the GPU, so only one BERT model is loaded onto the GPU. The problem is probably that your GPU RAM is not large enough for a batch of size 4.

Whether you train only the classifier layer is a design decision for your own model. It is entirely doable, but it may change model performance (for better or worse), so that's up to you.
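A minimal sketch of restricting training to the classifier, assuming the setup from the CoLA example (the `train_loader` iterator and the `loss` closure are placeholders, not from this thread):

```julia
using Flux
using Transformers

# Collect only the classifier head's parameters. Because the optimizer
# only updates what is in `ps`, the pretrained BERT layers stay frozen.
ps = params(bert_model.classifier)

opt = ADAM(1e-4)
for (x, y) in train_loader            # hypothetical data iterator
    grads = gradient(ps) do
        loss(x, y)                     # hypothetical loss closure over bert_model
    end
    Flux.update!(opt, ps, grads)
end
```

Swapping `params(bert_model.classifier)` back to `params(bert_model)` in the same loop fine-tunes the full model instead.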
I am trying to train several BERT classifier models in one program, but I am running out of GPU RAM by loading too many BERT models with
`const _bert_model, wordpiece, tokenizer = pretrain"Bert-uncased_L-12_H-768_A-12"`.
I am following the CoLA example found here: https://github.com/chengchingwen/Transformers.jl/blob/master/example/BERT/cola/train.jl

I am wondering whether the `train!()` function in the example trains all of the parameters in `Flux.params(bert_model)`, or only those in `Flux.params(_bert_model.classifier)`. The reason this is important is that, if only the classifier parameters are modified rather than all of the BERT model's parameters, then I can load one `pretrain"Bert-uncased_L-12_H-768_A-12"` into RAM instead of many, and just train a new classifier (`_bert_model.classifier`) for each BERT classifier I need. That saves a lot of RAM by not loading a full new BERT model for each classifier needed. Please let me know whether the whole BERT model is trained with the `train!()` function or just the classifier parameters.
Thank you,
Jack