utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0
1.85k stars 342 forks

Does the multi-label classifier optimize the entire model, or does it freeze BERT? #188

Open fabrahman opened 4 years ago

fabrahman commented 4 years ago

I wonder whether the default code optimizes the entire model end-to-end, or whether only the parameters of the additional classifier layer get updated?

Thanks

aaronbriel commented 4 years ago

I noticed that the freeze/requires_grad logic is commented out in learner_cls. @kaushaltrivedi, how does one freeze the BERT layers and train only the added custom layer?

aaronbriel commented 4 years ago

I tried adding a freeze_transformers_layer conditional to BertLearner's init function that sets requires_grad to False for every parameter whose name (from named_parameters) contains the model_type, but it didn't seem to have any effect. The same approach worked for me in another implementation and cut training times dramatically.
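A minimal sketch of the freezing approach described above, assuming a PyTorch-style model whose named_parameters() yields (name, parameter) pairs and whose backbone parameters share a prefix equal to the model type (e.g. "bert.", as in the printout later in this thread). freeze_backbone is a hypothetical helper, not part of fast-bert's API. It matches a "bert." prefix rather than a bare substring, so an unrelated parameter whose name merely contains the model type is not frozen by accident.

```python
def freeze_backbone(model, model_type="bert"):
    """Hypothetical helper (not fast-bert API): disable gradients for every
    parameter whose name starts with "<model_type>." and return their names.
    Parameters outside that prefix (e.g. the classifier head) stay trainable."""
    prefix = f"{model_type}."
    frozen = []
    for name, param in model.named_parameters():
        if name.startswith(prefix):
            param.requires_grad = False  # autograd will no longer compute a grad for it
            frozen.append(name)
    return frozen
```

With a real PyTorch model, call it before building the optimizer and, to be explicit, hand the optimizer only the parameters that still require gradients, e.g. torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=...).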

aaronbriel commented 4 years ago

https://github.com/kaushaltrivedi/fast-bert/pull/195

lingdoc commented 4 years ago

Just for clarification: freezing the layers means that only a single linear layer (i.e. the classifier head?) is trained for classification, rather than the weights of all the layers being updated during training?

aaronbriel commented 4 years ago

That is correct. I confirmed this by looping over the model's parameters and printing each parameter name along with its requires_grad setting (summarized for brevity):

bert.embeddings.word_embeddings.weight, requires_grad:False
bert.embeddings.position_embeddings.weight, requires_grad:False
bert.embeddings.token_type_embeddings.weight, requires_grad:False
...
bert.encoder.layer.11.output.LayerNorm.bias, requires_grad:False
bert.pooler.dense.weight, requires_grad:False
bert.pooler.dense.bias, requires_grad:False
classifier.weight, requires_grad:True
classifier.bias, requires_grad:True
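The check described above can be reproduced with a short loop. This sketch assumes the same named_parameters() interface as a PyTorch module; requires_grad_lines and trainable_names are hypothetical helper names, and the output format simply mirrors the printout above.

```python
def requires_grad_lines(model):
    """Hypothetical helper: one "name, requires_grad:flag" line per parameter tensor."""
    return [
        f"{name}, requires_grad:{param.requires_grad}"
        for name, param in model.named_parameters()
    ]


def trainable_names(model):
    """Hypothetical helper: names of the parameters that will still be updated."""
    return [name for name, param in model.named_parameters() if param.requires_grad]
```

On a frozen BERT classifier like the one shown above, printing requires_grad_lines(model) line by line should reproduce that listing, and trainable_names(model) should return only the classifier weight and bias.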