yangheng95 / PyABSA

Sentiment Analysis, Text Classification, Text Augmentation, Text Adversarial defense, etc.;
https://pyabsa.readthedocs.io
MIT License
955 stars, 161 forks

Fine tuning model on a custom APC dataset #136

Closed: BrandonFair closed this issue 2 years ago

BrandonFair commented 2 years ago

I'm having a bit of trouble fine-tuning any of the APC models on my custom APC dataset. The models seem to classify every instance as Neutral. The dataset is located in citation/apc_citation.train.txt and citation/apc_citation.test.txt.

Snippet of the data: [image: apc_citations_test]

Code I'm using to train:

from pyabsa.functional import Trainer
from pyabsa.functional import APCConfigManager
from pyabsa.functional import ABSADatasetList
from pyabsa.functional import APCModelList

apc_config_english = APCConfigManager.get_apc_config_english()
apc_config_english.model = APCModelList.SLIDE_LCF_BERT
apc_config_english.evaluate_begin = 0
apc_config_english.similarity_threshold = 1
apc_config_english.max_seq_len = 80
apc_config_english.dropout = 0.5
apc_config_english.log_step = 5
apc_config_english.l2reg = 0.0001
apc_config_english.dynamic_truncate = True
apc_config_english.srd_alignment = True

sent_classifier = Trainer(config=apc_config_english,
                          dataset='citation',
                          checkpoint_save_mode=2,
                          auto_device=True
                          ).load_trained_model()

Full output:

Construct DatasetItem from citation, assign dataset_name=citation...
Loading dataset cache: fast_lsa_t.custom_dataset.dataset.cache
/usr/local/lib/python3.6/dist-packages/pyabsa/core/apc/models/ensembler.py:67: ResourceWarning: unclosed file <_io.BufferedReader name='fast_lsa_t.custom_dataset.dataset.cache'>
  self.train_set, self.valid_set, self.test_set, opt = pickle.load(open(cache_path, mode='rb'))
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at microsoft/deberta-v3-base were not used when initializing DebertaV2Model: ['mask_predictions.dense.weight', 'lm_predictions.lm_head.dense.weight', 'lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.dense.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.dense.bias', 'mask_predictions.LayerNorm.bias', 'mask_predictions.LayerNorm.weight', 'mask_predictions.classifier.bias', 'mask_predictions.classifier.weight', 'lm_predictions.lm_head.bias']
- This IS expected if you are initializing DebertaV2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

2022-03-02 17:43:01,834 INFO: cuda memory allocated:790928896
2022-03-02 17:43:01,835 INFO: ABSADatasetsVersion:2022.01.29    --> Calling Count:0
2022-03-02 17:43:01,836 INFO: MV:<metric_visualizer.metric_visualizer.MetricVisualizer object at 0x7f3e68ea66a0>    --> Calling Count:0
2022-03-02 17:43:01,837 INFO: PyABSAVersion:1.8.29  --> Calling Count:0
2022-03-02 17:43:01,838 INFO: SRD:3 --> Calling Count:0
2022-03-02 17:43:01,838 INFO: TorchVersion:1.9.0+cu102+cuda10.2 --> Calling Count:0
2022-03-02 17:43:01,839 INFO: TransformersVersion:4.9.1 --> Calling Count:0
2022-03-02 17:43:01,840 INFO: auto_device:True  --> Calling Count:1
2022-03-02 17:43:01,840 INFO: batch_size:16 --> Calling Count:2
2022-03-02 17:43:01,841 INFO: cache_dataset:True    --> Calling Count:1
2022-03-02 17:43:01,842 INFO: cross_validate_fold:-1    --> Calling Count:0
2022-03-02 17:43:01,842 INFO: dataset_file:{'train': ['citation/apc_citation.train.txt'], 'test': ['citation/apc_citation.test.txt'], 'valid': []}  --> Calling Count:0
2022-03-02 17:43:01,843 INFO: dataset_name:custom_dataset   --> Calling Count:2
2022-03-02 17:43:01,844 INFO: dca_layer:3   --> Calling Count:0
2022-03-02 17:43:01,845 INFO: dca_p:1   --> Calling Count:0
2022-03-02 17:43:01,845 INFO: deep_ensemble:False   --> Calling Count:0
2022-03-02 17:43:01,849 INFO: device:cuda:0 --> Calling Count:6
2022-03-02 17:43:01,849 INFO: device_name:Tesla V100-PCIE-32GB  --> Calling Count:0
2022-03-02 17:43:01,850 INFO: dlcf_a:2  --> Calling Count:0
2022-03-02 17:43:01,851 INFO: dropout:0.5   --> Calling Count:1
2022-03-02 17:43:01,852 INFO: dynamic_truncate:True --> Calling Count:0
2022-03-02 17:43:01,852 INFO: embed_dim:768 --> Calling Count:7
2022-03-02 17:43:01,853 INFO: eta:-1    --> Calling Count:0
2022-03-02 17:43:01,853 INFO: evaluate_begin:0  --> Calling Count:0
2022-03-02 17:43:01,854 INFO: hidden_dim:768    --> Calling Count:0
2022-03-02 17:43:01,854 INFO: index_to_label:{0: 'Negative', 1: 'Neutral', 2: 'Positive'}   --> Calling Count:0
2022-03-02 17:43:01,854 INFO: initializer:xavier_uniform_   --> Calling Count:0
2022-03-02 17:43:01,855 INFO: inputs_cols:{'spc_mask_vec', 'left_lcf_vec', 'right_lcf_vec', 'lcf_vec', 'text_bert_indices'} --> Calling Count:2
2022-03-02 17:43:01,855 INFO: l2reg:0.0001  --> Calling Count:0
2022-03-02 17:43:01,856 INFO: label_to_index:{'Negative': 0, 'Neutral': 1, 'Positive': 2}   --> Calling Count:0
2022-03-02 17:43:01,856 INFO: lcf:cdw   --> Calling Count:0
2022-03-02 17:43:01,858 INFO: learning_rate:2e-05   --> Calling Count:0
2022-03-02 17:43:01,858 INFO: log_step:5    --> Calling Count:0
2022-03-02 17:43:01,858 INFO: lsa:False --> Calling Count:0
2022-03-02 17:43:01,859 INFO: max_seq_len:80    --> Calling Count:0
2022-03-02 17:43:01,859 INFO: model:<class 'pyabsa.core.apc.models.fast_lsa_t.FAST_LSA_T'>  --> Calling Count:5
2022-03-02 17:43:01,860 INFO: model_name:fast_lsa_t --> Calling Count:2
2022-03-02 17:43:01,860 INFO: model_path_to_save:checkpoints    --> Calling Count:0
2022-03-02 17:43:01,860 INFO: num_epoch:10  --> Calling Count:0
2022-03-02 17:43:01,861 INFO: optimizer:adam    --> Calling Count:0
2022-03-02 17:43:01,861 INFO: patience:99999    --> Calling Count:2
2022-03-02 17:43:01,862 INFO: polarities_dim:3  --> Calling Count:1
2022-03-02 17:43:01,862 INFO: pretrained_bert:microsoft/deberta-v3-base --> Calling Count:3
2022-03-02 17:43:01,862 INFO: save_mode:2   --> Calling Count:0
2022-03-02 17:43:01,863 INFO: seed:52   --> Calling Count:7
2022-03-02 17:43:01,863 INFO: show_metric:False --> Calling Count:0
2022-03-02 17:43:01,864 INFO: sigma:0.3 --> Calling Count:0
2022-03-02 17:43:01,864 INFO: similarity_threshold:1    --> Calling Count:0
2022-03-02 17:43:01,864 INFO: srd_alignment:True    --> Calling Count:0
2022-03-02 17:43:01,865 INFO: use_bert_spc:True --> Calling Count:0
2022-03-02 17:43:01,865 INFO: use_syntax_based_SRD:False    --> Calling Count:0
2022-03-02 17:43:01,866 INFO: window:lr --> Calling Count:0
2022-03-02 17:43:01,874 INFO: ***** Running training for Aspect Polarity Classification *****
2022-03-02 17:43:01,875 INFO: Training set examples = 4979
2022-03-02 17:43:01,881 INFO: Test set examples = 310
2022-03-02 17:43:01,881 INFO: Total params = 197414415, Trainable params = 197414415, Non-trainable params = 0
2022-03-02 17:43:01,882 INFO: Batch size = 16
2022-03-02 17:43:01,883 INFO: Num steps = 190
100%|██████████| 312/312 [03:18<00:00,  1.58it/s, Epoch:0 | Loss:0.3146 | Test Acc:84.19(max:84.19) | Test F1:30.47(max:30.47)]
100%|██████████| 312/312 [03:33<00:00,  1.46it/s, Epoch:1 | Loss:0.3084 | Test Acc:84.19(max:84.19) | Test F1:30.47(max:30.47)]
100%|██████████| 312/312 [03:16<00:00,  1.59it/s, Epoch:2 | Loss:0.3068 | Test Acc:84.19(max:84.19) | Test F1:30.47(max:30.47)]
100%|██████████| 312/312 [03:38<00:00,  1.43it/s, Epoch:3 | Loss:0.9095 | Test Acc:84.19(max:84.19) | Test F1:30.47(max:30.47)]
100%|██████████| 312/312 [03:15<00:00,  1.59it/s, Epoch:4 | Loss:1.3772 | Test Acc:84.19(max:84.19) | Test F1:30.47(max:30.47)]
 93%|█████████▎| 289/312 [03:19<00:12,  1.91it/s, Epoch:5 | Loss:0.5526 | Test Acc:84.19(max:84.19) | Test F1:30.47(max:30.47)]
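The flat Test Acc of 84.19 and Test F1 of 30.47 across every epoch are exactly what a model that predicts Neutral for every test example would score. The short calculation below checks this. It assumes the reported F1 is macro-averaged over the three labels, and it infers the number of Neutral test examples (261 of 310) from the accuracy; the split between Negative and Positive is unknown but does not affect the result.

```python
def f1(tp: int, fp: int, fn: int) -> float:
    """Per-class F1 written as 2TP / (2TP + FP + FN); 0.0 if the denominator is 0."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


n_test = 310                          # "Test set examples = 310" in the log
n_neutral = round(0.8419 * n_test)    # inferred from Test Acc 84.19 -> 261
n_other = n_test - n_neutral          # Negative + Positive together (split unknown)

# A degenerate classifier that always predicts "Neutral":
accuracy = n_neutral / n_test
f1_neutral = f1(tp=n_neutral, fp=n_other, fn=0)  # every non-Neutral example is a false positive
f1_negative = 0.0   # never predicted, so TP = 0 and F1 = 0 regardless of support
f1_positive = 0.0   # same reasoning
macro_f1 = (f1_neutral + f1_negative + f1_positive) / 3

print(f"accuracy={accuracy:.4f}, macro_f1={macro_f1:.4f}")
# prints: accuracy=0.8419, macro_f1=0.3047
```

Because the majority-class baseline already reaches about 84% accuracy, accuracy alone hides the collapse; the macro F1 near 1/3 of the Neutral F1 is the clearer signal.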

Please let me know if I should supply any more information. Any help will be greatly appreciated.

yangheng95 commented 2 years ago

First, I recommend setting learning_rate=1e-5 and l2reg=1e-8, since you are using deberta-v3-base.

Then, if the problem persists, please test whether other datasets work fine.
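Applied to the training script from the original post, the suggested change amounts to two extra config lines. This is a minimal sketch assuming the same PyABSA 1.8.x `pyabsa.functional` API used above; every other setting and the `Trainer(...)` call stay as they were.

```python
from pyabsa.functional import APCConfigManager

apc_config_english = APCConfigManager.get_apc_config_english()
# ... keep the other settings from the original script here ...

# Smaller step size for fine-tuning deberta-v3-base
apc_config_english.learning_rate = 1e-5
# Much weaker L2 regularization than the 1e-4 used originally
apc_config_english.l2reg = 1e-8
```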

yangheng95 commented 2 years ago

As a last resort, you could share a balanced cut of your dataset so I can debug it.
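One way to produce such a balanced cut is to downsample every polarity to the size of the rarest one. The sketch below assumes the three-lines-per-example APC layout (sentence with the aspect replaced by `$T$`, then the aspect term, then the polarity label); the helper names `read_apc_triplets`, `balanced_subset` and `write_apc_triplets` are hypothetical and not part of PyABSA.

```python
import random
from collections import defaultdict


def read_apc_triplets(lines):
    """Group APC lines into (sentence, aspect, label) triplets.

    Assumes 3 non-blank lines per example; a trailing incomplete group is dropped.
    """
    rows = [ln.rstrip("\n") for ln in lines if ln.strip()]
    usable = len(rows) - len(rows) % 3
    return [tuple(rows[i:i + 3]) for i in range(0, usable, 3)]


def balanced_subset(triplets, seed=0):
    """Randomly downsample every label to the count of the rarest label."""
    by_label = defaultdict(list)
    for triplet in triplets:
        by_label[triplet[2]].append(triplet)
    n = min(len(group) for group in by_label.values())
    rng = random.Random(seed)  # fixed seed so the cut is reproducible
    subset = []
    for label in sorted(by_label):
        subset.extend(rng.sample(by_label[label], n))
    rng.shuffle(subset)
    return subset


def write_apc_triplets(triplets, path):
    """Write triplets back out in the same 3-line layout."""
    with open(path, "w", encoding="utf-8") as f:
        for sentence, aspect, label in triplets:
            f.write(f"{sentence}\n{aspect}\n{label}\n")
```

For example, passing an open handle of `citation/apc_citation.train.txt` to `read_apc_triplets`, then feeding the result through `balanced_subset` and `write_apc_triplets`, gives a file with equal counts per label. With a heavily Neutral dataset most examples are dropped, so the cut is meant for debugging rather than for training.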

BrandonFair commented 2 years ago

It works, thank you :). It's also performing quite well.

My dataset is quite small (about 5,000 examples) and very imbalanced (90% Neutral). Are there any other parameters that might be worth investigating? Is there another ABSA model that might perform better? Is it worth first fine-tuning the model on another dataset (such as SemEval) and then fine-tuning it on my dataset?

Sorry about all the questions, but any answers will be invaluable.

yangheng95 commented 2 years ago

You can try your last idea; beyond that I have no further recommendations. You can find more ABSA research on GitHub. Good luck!

pepi99 commented 2 years ago

Hey, Brandon, can I ask you for some help with the process?

enzo-ca commented 2 years ago

@BrandonFair did you ever try your last question? If so, what were the results?

BrandonFair commented 1 year ago

Sorry for the late reply @pepi99. Yes, I can help, but it seems you might have found a solution in another issue.

BrandonFair commented 1 year ago

From what I recall, it worked best when I fine-tuned only on the target corpus, @enzo-ca. However, my target corpus and SemEval come from very different domains.