Trying to run K-fold cross validation for lcf-bert yields this error message:
```
Traceback (most recent call last):
  File "train_k_fold_cross_val.py", line 315, in <module>
    main()
  File "train_k_fold_cross_val.py", line 311, in main
    ins.run()
  File "train_k_fold_cross_val.py", line 181, in run
    self._reset_params()
  File "train_k_fold_cross_val.py", line 84, in _reset_params
    self.model.bert.load_state_dict(self.pretrained_bert_state_dict)
  File "/Users/joaoleite/anaconda3/envs/absa/lib/python3.6/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'LCF_BERT' object has no attribute 'bert'
```
I've noticed that the `_reset_params` functions in train.py and train_k_fold_cross_val.py are identical apart from this else statement: `else: self.model.bert.load_state_dict(self.pretrained_bert_state_dict)`. So I removed it and everything worked fine. Tested with LCF-BERT and BERT_SPC.
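An alternative to deleting the branch outright would be to guard the reload on the attribute's presence. This is a hypothetical sketch, not the repo's code: `FakeBert`, `BertSPC`, `LCFBert`, and `reset_bert_weights` are illustrative stand-ins mirroring the shapes involved in the error above.

```python
# Hypothetical sketch: guard the pretrained-BERT reload so that models which
# don't expose a `bert` attribute (like LCF_BERT in the traceback above)
# don't raise AttributeError during _reset_params.

class FakeBert:
    """Minimal stand-in for a BERT module with load_state_dict."""
    def __init__(self):
        self.state = {}

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict)

class BertSPC:
    """Stand-in for a model that exposes a .bert submodule (like BERT_SPC)."""
    def __init__(self):
        self.bert = FakeBert()

class LCFBert:
    """Stand-in for a model with no .bert attribute (like LCF_BERT)."""
    pass

def reset_bert_weights(model, pretrained_bert_state_dict):
    # Only reload pretrained weights when the model actually has a `bert`
    # submodule; otherwise skip, instead of crashing with AttributeError.
    if hasattr(model, "bert"):
        model.bert.load_state_dict(pretrained_bert_state_dict)
        return True
    return False
```

With this guard, `reset_bert_weights(BertSPC(), saved_weights)` reloads and returns `True`, while `reset_bert_weights(LCFBert(), saved_weights)` is a no-op returning `False`.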
Also, I've noticed that running `pip install -r requirements.txt` installs pytorch 1.10.x, which is incompatible. I've added an upper limit of `<=1.4.0` and it works.
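For reference, the pinned line in requirements.txt would look like this (a sketch with only the upper bound mentioned above; any lower bound is left out since the issue doesn't state one):

```
# requirements.txt — keep torch below the incompatible 1.10.x releases
torch<=1.4.0
```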