utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

BertLearner.from_pretrained_model() shows error: name 'pos_weight' is not defined #273

Open steventu27 opened 3 years ago

steventu27 commented 3 years ago

I tried the same code as in https://github.com/kaushaltrivedi/fast-bert/blob/master/test/multi_class.ipynb

There is an error when I run:

    learner = BertLearner.from_pretrained_model(
        databunch,
        pretrained_path=BERT_PRETRAINED_PATH,
        metrics=metrics,
        device=device,
        logger=logger,
        output_dir=args.output_dir,
        finetuned_wgts_path=FINETUNED_PATH,
        warmup_steps=args.warmup_steps,
        multi_gpu=args.multi_gpu,
        is_fp16=args.fp16,
        multi_label=True,
        logging_steps=0)

Here is the error:

    NameError                                 Traceback (most recent call last)
     in
          3     finetuned_wgts_path=FINETUNED_PATH, warmup_steps=args.warmup_steps,
          4     multi_gpu=args.multi_gpu, is_fp16=args.fp16,
    ----> 5     multi_label=True, logging_steps=0)

    ~/anaconda3/lib/python3.7/site-packages/fast_bert/learner_cls.py in from_pretrained_model(dataBunch, pretrained_path, output_dir, metrics, device, logger, finetuned_wgts_path, multi_gpu, is_fp16, loss_scale, warmup_steps, fp16_opt_level, grad_accumulation_steps, multi_label, max_grad_norm, adam_epsilon, logging_steps, freeze_transformer_layers, pos_weight, weight)
        194
        195         model = load_model(
    --> 196             dataBunch, pretrained_path, finetuned_wgts_path, device, multi_label
        197         )
        198

    ~/anaconda3/lib/python3.7/site-packages/fast_bert/learner_cls.py in load_model(dataBunch, pretrained_path, finetuned_wgts_path, device, multi_label)
        144         config_class, model_class, _ = MODEL_CLASSES[model_type]
        145
    --> 146         model_class[1].pos_weight = pos_weight
        147         model_class[1].weight = weight
        148

    NameError: name 'pos_weight' is not defined

Do you know how to fix this? Thanks!!
JiaWang-seek commented 3 years ago

I had the same issue! Rolling back to an older version, 1.9.1, worked for me:

pip install fast_bert==1.9.1
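
After pinning the version, it may help to confirm which version the running interpreter actually picks up, since notebooks often use a different environment than the shell where pip was run. A minimal sketch using only the standard library (Python 3.8+); the helper name `installed_version` is made up for illustration and is not part of fast-bert:

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a package, or None if it is not installed.

    Hypothetical helper for illustration; not part of fast-bert.
    """
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Should report 1.9.1 if the rollback above took effect in this environment.
    print(installed_version("fast_bert"))
```

Run this in the same notebook kernel that raised the error; if it prints a different version (or None), the rollback went into another environment.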

steventu27 commented 3 years ago

Thanks! I solved it by downloading this package and loading it.
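
For anyone reading the traceback later: it shows `load_model` being defined with only `(dataBunch, pretrained_path, finetuned_wgts_path, device, multi_label)`, while its body reads `pos_weight` and `weight`. Those names exist only as parameters of `from_pretrained_model`, so Python cannot resolve them inside `load_model`. The sketch below reproduces that pattern with simplified stand-in functions. It is not fast-bert's real code, and the "fixed" variant is only an assumption about how the values could be passed through, not the library's actual patch.

```python
def load_model_buggy(multi_label):
    # Stand-in for the failing helper: pos_weight and weight are neither
    # parameters nor globals here, so the lookup fails when the line runs.
    return {"multi_label": multi_label, "pos_weight": pos_weight, "weight": weight}


def load_model_fixed(multi_label, pos_weight=None, weight=None):
    # Assumed fix: accept the values as explicit parameters.
    return {"multi_label": multi_label, "pos_weight": pos_weight, "weight": weight}


def from_pretrained_model(multi_label=False, pos_weight=None, weight=None):
    # The caller already has pos_weight and weight; it has to forward them.
    return load_model_fixed(multi_label, pos_weight=pos_weight, weight=weight)


if __name__ == "__main__":
    try:
        load_model_buggy(True)
    except NameError as err:
        print("buggy helper:", err)
    print("fixed helper:", from_pretrained_model(multi_label=True))
```

Until the installed release handles this, pinning an older version as suggested above is the workaround that was reported to work in this thread.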