I have a quick first pass at point 1 in a fork, based on how rust-bert
handles it: https://github.com/bkonkle/burn-models/blob/707153f5ef1f1f2e8478cebf45e3ca58247d8348/bert-burn/src/model.rs#L168-L171
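For context, the rust-bert pooler linked above amounts to taking the first ([CLS]) token's hidden state and running it through a dense layer with a tanh activation. A minimal Burn sketch of that idea is below; the `Pooler` type, its field names, and the `init(device)` call are my assumptions against a recent Burn release, not the fork's actual code:

```rust
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::{activation, backend::Backend, Tensor};

/// Pools the encoder output by taking the hidden state of the first ([CLS])
/// token and passing it through a dense layer followed by tanh.
#[derive(Module, Debug)]
pub struct Pooler<B: Backend> {
    dense: Linear<B>,
}

impl<B: Backend> Pooler<B> {
    pub fn new(hidden_size: usize, device: &B::Device) -> Self {
        Self {
            dense: LinearConfig::new(hidden_size, hidden_size).init(device),
        }
    }

    /// `hidden_states`: [batch_size, seq_length, hidden_size]
    /// returns:         [batch_size, hidden_size]
    pub fn forward(&self, hidden_states: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, _seq_length, _hidden_size] = hidden_states.dims();
        // Hidden state of the first token for each sequence in the batch.
        let first_token: Tensor<B, 2> = hidden_states
            .slice([0..batch_size, 0..1])
            .squeeze(1);
        activation::tanh(self.dense.forward(first_token))
    }
}
```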
How would I approach point 2?
I'm making more progress on the very beginnings of a transformers-style library for Burn, using traits for pipeline implementations, but in my WIP testing so far I'm having trouble getting training to work correctly. It doesn't seem to be using the pre-trained weights from bert-base-uncased correctly, so accuracy fluctuates between roughly 25% and 50%.
https://github.com/bkonkle/burn-transformers
This is using my branch with pooled BERT output. The branch doesn't currently build, but I plan to do more work on it this week to fix that and get a good example in place for feedback.
Awesome @bkonkle! I think the current implementation is using RoBERTa weights instead of BERT, so maybe this isn't compatible with the BERT weights for the classification head. Not sure if this helps, but if you find something not working, make sure to test multiple backends and report a bug if there are differences.
Okay, I believe I understand goal 2 better now. I was thinking this meant a flag for `all_hidden_states`, like this flag in Huggingface's transformers library. I now believe that this means just the full Tensor from the last `hidden_states` value, like this property in Huggingface's transformers. This would correspond with the final `x` value in Burn's Transformer Encoder, here.
If my interpretation is correct, I believe the approach in my fork here addresses this by returning both the last hidden states and, when available, the pooled output, roughly as sketched below.
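The shape of that return value would look something like the following. This is a sketch only; the struct and field names are illustrative rather than copied from the branch:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Sketch of a BERT forward-pass output that exposes both the full sequence of
/// last-layer hidden states and an optional pooled representation, similar in
/// spirit to Huggingface's BaseModelOutputWithPooling.
pub struct BertOutput<B: Backend> {
    /// Final encoder output: [batch_size, seq_length, hidden_size]
    pub hidden_states: Tensor<B, 3>,
    /// Pooled [CLS] representation, present only when the pooler is enabled:
    /// [batch_size, hidden_size]
    pub pooled_output: Option<Tensor<B, 2>>,
}
```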
Update: Solved - see the next comment below.
The learning rate was indeed a hint. I had it set way too low, based on the default value in the JointBERT repo I was learning from. :sweat_smile: Setting the learning rate to `1e-2` solves my problem, so I think my branch is ready for some review to see if this is the right approach to enabling custom BERT model outputs. :+1:
======================== Learner Summary ========================
Model: Model[num_params=109489161]
Total Epochs: 10

| Split | Metric | Min. | Epoch | Max. | Epoch |
|---|---|---|---|---|---|
| Train | Loss | 0.003 | 10 | 0.246 | 1 |
| Train | Accuracy | 92.540 | 1 | 99.900 | 10 |
| Train | Learning Rate | 0.000 | 10 | 0.000 | 1 |
| Valid | Loss | 0.060 | 10 | 0.109 | 8 |
| Valid | Accuracy | 96.900 | 8 | 98.100 | 10 |
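For anyone who hits the same thing, the change really was just the learning-rate value in the training config. A minimal sketch of that kind of config in Burn follows; the struct name, other fields, and their defaults are illustrative assumptions, not the actual burn-transformers code:

```rust
use burn::config::Config;
use burn::optim::AdamConfig;

// Illustrative training config; only `learning_rate` matters for this fix.
#[derive(Config)]
pub struct TrainingConfig {
    pub optimizer: AdamConfig,
    #[config(default = 10)]
    pub num_epochs: usize,
    #[config(default = 32)]
    pub batch_size: usize,
    // Previously copied from JointBERT's much smaller default; 1e-2 is what
    // actually converged here.
    #[config(default = 1e-2)]
    pub learning_rate: f64,
}
```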
For flexibility when fine-tuning on downstream tasks, we should have the following options in the BERT family model outputs: