utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

Parameterized weights for loss functions #236

Closed aaronbriel closed 3 years ago

aaronbriel commented 4 years ago

Added pos_weight and weight parameters to the BertDataBunch constructor and to the from_pretrained_model method of the BertLearner class, which sets these attributes on the multilabel classifiers in model_class. The attributes are then passed to the BCEWithLogitsLoss instantiations in the forward functions of the multilabel classification classes in modeling.
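For readers unfamiliar with the two parameters, here is a minimal plain-Python sketch of the per-label loss that the PyTorch documentation gives for BCEWithLogitsLoss when weight and pos_weight are supplied. It is an illustration of the formula only, not the code in this PR; the function names bce_with_logits and softplus are hypothetical, and mean reduction over labels is assumed.

```python
import math


def softplus(x):
    # log(1 + exp(x)), written so that large |x| does not overflow
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logits, targets, weight=None, pos_weight=None):
    """Mean weighted binary cross-entropy over the labels of one example.

    weight[c] rescales the whole loss term of label c;
    pos_weight[c] rescales only the positive (target == 1) part.
    """
    n = len(logits)
    weight = weight or [1.0] * n
    pos_weight = pos_weight or [1.0] * n
    total = 0.0
    for x, y, w, p in zip(logits, targets, weight, pos_weight):
        log_sig = -softplus(-x)           # log(sigmoid(x))
        log_one_minus_sig = -softplus(x)  # log(1 - sigmoid(x))
        total += -w * (p * y * log_sig + (1.0 - y) * log_one_minus_sig)
    return total / n


logits = [0.0, 0.0]
targets = [1.0, 0.0]
base = bce_with_logits(logits, targets)
boosted = bce_with_logits(logits, targets, pos_weight=[3.0, 1.0])
rescaled = bce_with_logits(logits, targets, weight=[2.0, 1.0])
```

With these inputs, pos_weight triples the contribution of the positive first label, while weight doubles the first label's term regardless of its target.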

A non-multilabel run with the master branch, using a sample weighting scheme with a custom sampler but no pos_weight or weight parameters, produced accuracy and a confusion matrix consistent with prior runs.

Class weighting showed improved accuracy and precision compared with sample weighting.
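The PR does not say how the class weights were chosen. As one possible starting point, the sketch below derives a per-label pos_weight as the ratio of negative to positive examples, a heuristic suggested in the PyTorch documentation for BCEWithLogitsLoss. The helper name pos_weight_from_labels and the fallback of 1.0 for labels with no positives are assumptions made for this illustration.

```python
def pos_weight_from_labels(label_rows):
    """Return negatives / positives for each label column.

    label_rows is a list of equal-length 0/1 lists, one per example.
    Labels with no positive example fall back to 1.0 (an assumption).
    """
    n_rows = len(label_rows)
    n_labels = len(label_rows[0])
    weights = []
    for c in range(n_labels):
        positives = sum(row[c] for row in label_rows)
        negatives = n_rows - positives
        weights.append(negatives / positives if positives else 1.0)
    return weights


labels = [
    [1, 0],
    [0, 1],
    [0, 1],
    [0, 0],
]
pos_weight = pos_weight_from_labels(labels)
```

The resulting list has one entry per label and could be converted to a tensor before being passed as pos_weight.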

aaronbriel commented 4 years ago

@kaushaltrivedi I resolved the merge conflicts. Any way you could look this over?

kaushaltrivedi commented 3 years ago

merged