mayuefine / c_AMPs-prediction

GNU General Public License v3.0

Error in loading model #7

Open · LIUSIFAN20221219 opened this issue 1 year ago

LIUSIFAN20221219 commented 1 year ago

Hello, when I load the already fine-tuned BERT model using the code from the BERT Model section of c_AMPs-Prediction.md, I get an error in the get_params(self, deep) function of scikit-learn's own base.py: AttributeError: 'BertClassifier' object has no attribute 'bert_model'

How can I fix this bug?
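The error can be reproduced without BERT or scikit-learn. The sketch below imitates, in plain Python, the lookup that scikit-learn's `get_params` performs: it reads every constructor argument back from the instance with `getattr`, and (per the last comment in this thread) from 0.24 on a missing attribute raises instead of being tolerated. `ToyEstimator` and `restore_without_attrs` are hypothetical stand-ins, not code from this repository.

```python
import inspect


class ToyEstimator:
    """Hypothetical stand-in for BertClassifier; not the real class."""

    def __init__(self, label_list=None, bert_model="bert-base-uncased", epochs=3):
        self.label_list = label_list
        self.bert_model = bert_model
        self.epochs = epochs

    def get_params(self, deep=True):
        # Imitates scikit-learn >= 0.24: every __init__ parameter must exist
        # as an instance attribute, otherwise getattr raises AttributeError.
        names = [p for p in inspect.signature(type(self).__init__).parameters if p != "self"]
        return {name: getattr(self, name) for name in names}


def restore_without_attrs(state):
    # Hypothetical partial restore: create the object without running
    # __init__ and copy back only one of the saved parameters.
    obj = ToyEstimator.__new__(ToyEstimator)
    obj.label_list = state["params"]["label_list"]
    return obj


state = {"params": {"label_list": [0, 1], "bert_model": "bert-base-uncased", "epochs": 3}}
broken = restore_without_attrs(state)
try:
    broken.get_params()
except AttributeError as err:
    print(err)  # 'ToyEstimator' object has no attribute 'bert_model'
```

The failure comes from the attribute being absent on the restored object, not from the saved file itself: `state` still contains `bert_model`.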

jkwang93 commented 1 year ago

I ran into the same problem and solved it with this code:

params = state['params']
self.set_params(**params)
self.label_list = params['label_list']
self.bert_model = params['bert_model']
self.num_mlp_hiddens = params['num_mlp_hiddens']
self.num_mlp_layers = params['num_mlp_layers']
self.restore_file = params['restore_file']
self.epochs = params['epochs']
self.max_seq_length = params['max_seq_length']
self.train_batch_size = params['train_batch_size']
self.eval_batch_size = params['eval_batch_size']
self.learning_rate = params['learning_rate']
self.warmup_proportion = params['warmup_proportion']
self.gradient_accumulation_steps = params['gradient_accumulation_steps']
self.fp16 = params['fp16']
self.loss_scale = params['loss_scale']
self.local_rank = params['local_rank']
self.use_cuda = params['use_cuda']
self.random_state = params['random_state']
self.validation_fraction = params['validation_fraction']
self.logfile = params['logfile']
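The idea behind the assignments above is to make every saved parameter exist as an attribute again. A minimal sketch of the same idea with a loop, assuming the keys of `state['params']` match the constructor argument names exactly; `ToyEstimator` and `restore_params` are hypothetical names, not part of this repository or of scikit-learn.

```python
import inspect


class ToyEstimator:
    """Hypothetical stand-in for BertClassifier; not the real class."""

    def __init__(self, label_list=None, bert_model="bert-base-uncased", epochs=3):
        self.label_list = label_list
        self.bert_model = bert_model
        self.epochs = epochs

    def get_params(self, deep=True):
        # Same attribute lookup that scikit-learn >= 0.24 performs.
        names = [p for p in inspect.signature(type(self).__init__).parameters if p != "self"]
        return {name: getattr(self, name) for name in names}

    def restore_params(self, state):
        # Hypothetical restore step: copy every saved parameter onto the
        # instance, replacing one hand-written assignment per parameter.
        for name, value in state["params"].items():
            setattr(self, name, value)
        return self


state = {"params": {"label_list": [0, 1], "bert_model": "bert-base-uncased", "epochs": 3}}
model = ToyEstimator.__new__(ToyEstimator).restore_params(state)
print(model.get_params())  # {'label_list': [0, 1], 'bert_model': 'bert-base-uncased', 'epochs': 3}
```

The sketch uses `setattr` rather than `set_params` on purpose: in scikit-learn's `BaseEstimator`, `set_params` starts by calling `get_params`, so on 0.24+ the attributes need to exist before `set_params` can run.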

KeyLllll commented 6 months ago

@jkwang93 Hi! Which lines did you modify in the code? Could you explain in detail? Thanks!

cfz1998 commented 3 months ago


Maybe around line 449 of sklearn.py.

RChGO commented 2 weeks ago

I encountered the same issue. It is caused by having scikit-learn 0.24 or later installed. After I downgraded to the version the author specifies (0.22.1), the problem was resolved.
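A small check can tell whether an installed scikit-learn falls in the affected range before loading the model. A minimal sketch, assuming the version string starts with `major.minor` (as `sklearn.__version__` does) and taking 0.24 as the cutoff reported above; `get_params_is_strict` is a hypothetical helper name.

```python
import re


def get_params_is_strict(version):
    """Return True for scikit-learn versions (0.24 and later) whose
    get_params raises AttributeError for missing constructor attributes."""
    # Use only the leading "major.minor" numbers; suffixes such as
    # ".post1" or "rc1" are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (0, 24)


print(get_params_is_strict("0.22.1"))  # False
print(get_params_is_strict("1.3.0"))   # True
```

In practice it would be called as `get_params_is_strict(sklearn.__version__)`; when it returns True, the options in this thread are to pin the older release (`pip install scikit-learn==0.22.1`) or to restore every parameter as an attribute before `get_params`/`set_params` is called.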