utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0
1.85k stars 342 forks

fastbert won't output sigmoid for multilabel task #224

Closed: hy951221 closed this issue 4 years ago

hy951221 commented 4 years ago

Hello, I am new to BERT and neural networks. I followed the tutorial and was able to produce nice results for multi-class tasks. However, for a multi-label task the predictions look like softmax output rather than a sigmoid score for each label, even though I set multi_label=True. Thanks in advance for your help! This is the (truncated) prediction output:

    [[('Easy to set up/use(pos)', 0.39326855540275574),
      ('Good control(pos)', 0.19075249135494232),
      ('Works well with echo/google home(pos)', 0.15823547542095184),
      ('Overall functionality(pos)', 0.15642541646957397),
      ('Timer/Programing(pos)', 0.08805117011070251),
      ('Good connect/pair(neg)', 0.06877072155475616),
      ('Easy to set up/use(neg)', 0.05767549201846123),
      ('Price(pos)', 0.05100243166089058),
      ('Good control(neg)', 0.04893764108419418),
      ('Stability/quick response(pos)', 0.04889653995633125),
      ('Durability(pos)', 0.04828120023012161),
      ('Customer service(pos)', 0.045572679489851),
      ('Retain/ sync memory(neg)', 0.04357598349452019),
      ('Durability(neg)', 0.04130781069397926),
      ('Good connect/pair(pos)', 0.04110244661569595),
      ('Works well with echo/google home(neg)', 0.04066279157996178),
      ('Packaging/IM(pos)', 0.03957880660891533),
      ('Aesthetic(neg)', 0.039149072021245956),
      ('5G support(pos)', 0.039017967879772186),

hy951221 commented 4 years ago

This is how I instantiate the databunch and learner objects:

    from fast_bert.data_cls import BertDataBunch
    from fast_bert.learner_cls import BertLearner

    # Data bunch for a multi-label task: one label column per class.
    databunch = BertDataBunch(
        DATA_PATH,
        LABEL_PATH,
        tokenizer='bert-base-uncased',
        train_file='/content/drive/My Drive/Smart BUZZ/train.csv',
        val_file='/content/drive/My Drive/Smart BUZZ/test.csv',
        label_file='/content/drive/My Drive/Smart BUZZ/labels.csv',
        text_col='Review Text',
        label_col=list(train.columns[1:]),
        batch_size_per_gpu=16,
        max_seq_length=128,
        multi_gpu=True,
        multi_label=True,
        model_type='bert',
    )

    # Learner built on the pretrained model, also with multi_label=True.
    learner = BertLearner.from_pretrained_model(
        databunch,
        pretrained_path='bert-base-uncased',
        metrics=metrics,
        device=device_cuda,
        logger=logger,
        output_dir=MODEL_PATH,
        finetuned_wgts_path=None,
        warmup_steps=50,
        multi_gpu=True,
        is_fp16=False,
        multi_label=True,
        logging_steps=0,
    )
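
One quick way to tell softmax scores from per-label sigmoid scores is to add them up: softmax normalises across labels, so its outputs sum to 1, while sigmoid scores each label on its own and the total can be anything. The sketch below is plain Python with hand-written `softmax` and `sigmoid` helpers and made-up logits (both are illustrative assumptions, not fast-bert code). The six `posted_scores` are copied from the prediction list in the first comment, assuming that list holds the scores for a single review.

```python
import math


def softmax(logits):
    # Normalises across labels, so the outputs always sum to 1.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(logits):
    # Scores each label independently, so the outputs need not sum to 1.
    return [1.0 / (1.0 + math.exp(-x)) for x in logits]


# Hypothetical logits for three labels (illustrative values only).
logits = [2.0, 1.0, -1.0]
softmax_scores = softmax(logits)
sigmoid_scores = sigmoid(logits)

# First six scores copied from the prediction output posted in the issue.
posted_scores = [
    0.39326855540275574,
    0.19075249135494232,
    0.15823547542095184,
    0.15642541646957397,
    0.08805117011070251,
    0.06877072155475616,
]
posted_total = sum(posted_scores)

print(f"softmax sum: {sum(softmax_scores):.4f}")
print(f"sigmoid sum: {sum(sigmoid_scores):.4f}")
print(f"sum of first six posted scores: {posted_total:.4f}")
```

The first six posted scores already add up to more than 1, which a softmax over the labels could not produce. That is consistent with independent per-label scores, although it does not by itself show which activation the library applies internally.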