baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0

HF transformer: evaluate() performs one sample with MC Dropout #279

Closed. arthur-thuy closed this issue 12 months ago

arthur-thuy commented 1 year ago

Describe the bug

The HuggingFace + Baal example scripts (e.g. nlp_bert_mcdropout.py and this gist) use the BaalTransformersTrainer class. MC Dropout is activated, and the model's performance is evaluated with the HF evaluate() function.

The evaluate() function performs only one forward pass while MC Dropout is activated. I think this is incorrect, because a single stochastic pass essentially predicts with reduced model capacity. When only one sample is predicted, MC Dropout should be deactivated. Alternatively, with MC Dropout activated, multiple forward passes can be drawn to obtain the predictive distribution.
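To make the concern concrete, here is a minimal, framework-free sketch contrasting a single stochastic MC Dropout pass with the average of many passes. It is a toy linear model in plain Python, not Baal or PyTorch code; linear_forward and mc_dropout_predict are hypothetical names. It assumes inverted dropout, where kept inputs are scaled by 1/(1-p), so for a linear model the expected stochastic output equals the deterministic one.

```python
import random
import statistics


def linear_forward(weights, x, p=0.0, rng=None):
    """Toy linear model with inverted dropout on its inputs (hypothetical).

    p=0.0 gives the deterministic prediction (dropout off). With p > 0 each
    input is dropped with probability p and kept inputs are scaled by
    1/(1-p), so the expected output matches the deterministic one.
    """
    total = 0.0
    for w, xi in zip(weights, x):
        if p > 0.0:
            if rng.random() < p:
                continue  # this input is dropped for this pass
            xi = xi / (1.0 - p)
        total += w * xi
    return total


def mc_dropout_predict(weights, x, p, iterations, seed=0):
    """Average `iterations` stochastic forward passes (MC Dropout)."""
    rng = random.Random(seed)
    samples = [linear_forward(weights, x, p, rng) for _ in range(iterations)]
    return statistics.mean(samples), samples


weights = [0.5, -1.0, 2.0, 0.25]
x = [1.0, 2.0, 3.0, 4.0]

deterministic = linear_forward(weights, x)                  # dropout off
single, _ = mc_dropout_predict(weights, x, 0.5, 1)          # one noisy pass
averaged, samples = mc_dropout_predict(weights, x, 0.5, 5000)

print(f"deterministic: {deterministic:.3f}")
print(f"single MC pass: {single:.3f}")
print(f"mean of {len(samples)} MC passes: {averaged:.3f}")
```

For a real network with nonlinearities the averaged prediction only approximates the deterministic one, but the issue raised above is the same: one stochastic pass has high variance, while many passes give a stable estimate.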

To Reproduce

/

Expected behavior

It would be nice to have an iterations argument in the evaluate() function. When iterations == 1, the model is unpatched; when iterations > 1, multiple samples are drawn with predict_on_dataset, as in the gist.

I guess the BaalTransformersTrainer's evaluate() function needs to be overridden. I'm not sure whether the model could be temporarily unpatched. Metrics should also be handled the way the HF evaluate() function currently handles them.
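The proposed behaviour can be outlined as a small dispatch function. This is a hypothetical, framework-free sketch: ToyModel, its stochastic flag, dropout_mode and evaluate_with_iterations are illustrative names, not part of Baal or HuggingFace. In Baal, the mode switch would presumably correspond to patching or unpatching dropout, and the averaging to drawing samples with predict_on_dataset.

```python
import random
from contextlib import contextmanager


class ToyModel:
    """Stand-in for a network whose dropout can be switched on or off (hypothetical)."""

    def __init__(self, weight, p=0.5, seed=0):
        self.weight = weight
        self.p = p
        self.stochastic = True  # "patched": dropout active at inference
        self._rng = random.Random(seed)

    def __call__(self, x):
        if self.stochastic and self._rng.random() < self.p:
            return 0.0  # input dropped on this pass
        scale = 1.0 / (1.0 - self.p) if self.stochastic else 1.0
        return self.weight * x * scale


@contextmanager
def dropout_mode(model, stochastic):
    """Temporarily set the dropout mode, restoring the previous one on exit."""
    previous = model.stochastic
    model.stochastic = stochastic
    try:
        yield model
    finally:
        model.stochastic = previous


def evaluate_with_iterations(model, inputs, iterations=1):
    """Hypothetical evaluate(): 1 iteration -> deterministic, >1 -> MC average."""
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    if iterations == 1:
        # Single prediction: temporarily "unpatch" so dropout is off.
        with dropout_mode(model, stochastic=False):
            return [model(x) for x in inputs]
    # Several predictions: keep dropout on and average the stochastic passes.
    with dropout_mode(model, stochastic=True):
        return [sum(model(x) for _ in range(iterations)) / iterations for x in inputs]


model = ToyModel(weight=2.0)
print(evaluate_with_iterations(model, [1.0, 3.0]))        # deterministic
print(evaluate_with_iterations(model, [1.0, 3.0], 2000))  # MC average
print(model.stochastic)  # original mode restored after evaluation
```

Restoring the previous mode in a finally block is what makes the switch temporary: training that resumes after evaluation still sees dropout.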

Version (please complete the following information):

Additional context /

Do you think this is feasible?

Thank you

Dref360 commented 1 year ago

Sorry for the late answer.

It is quite easy to "temporarily unpatch the model", so I would go this route.

Something like:

model = ...
trainer = BaalTransformersTrainer(...)

with patch_module(model) as _mc_dropout_model:
    # The model is stochastic
    trainer.train()

# The model is deterministic
trainer.evaluate()

Let me know if that works for you; if not, we can probably implement your suggestions.

arthur-thuy commented 12 months ago

Hi @Dref360,

I tried creating an MWE (see this gist), but the with statement raises AttributeError: __enter__. The gist is adapted from the experiments/nlp_bert_mcdropout.py Baal example.

Here is the full output of the script:

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
========== Without MC Dropout ==========
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:02<00:00,  6.38it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:02<00:00,  6.42it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:02<00:00,  6.36it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:02<00:00,  6.42it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:02<00:00,  6.40it/s]
[0.4908256880733945, 0.4908256880733945, 0.4908256880733945, 0.4908256880733945, 0.4908256880733945]
========== With MC Dropout ==========
Traceback (most recent call last):
  File "/home/abthuy/Documents/PhD research/active-uncertainty/src/check_unpatch", line 131, in <module>
    main()
  File "/home/abthuy/Documents/PhD research/active-uncertainty/src/check_unpatch", line 105, in main
    with patch_module(model) as _mc_dropout_model:
AttributeError: __enter__

Do you have an idea how this could be fixed? I also tried calling patch_module on trainer.model instead of model, but that resulted in the same error.

Thank you
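The traceback is consistent with patch_module returning the patched model itself rather than a context manager: a with statement needs an object that implements __enter__ and __exit__. The stdlib-only sketch below reproduces the failure; PatchedModel, patch and try_with are hypothetical stand-ins, not Baal code. Note that Python 3.11 and later raise TypeError ("does not support the context manager protocol") instead of AttributeError: __enter__.

```python
class PatchedModel:
    """Hypothetical stand-in for a model object: it defines no __enter__/__exit__."""


def patch(model):
    """Hypothetical function-style patch: modifies a model and returns it."""
    return model


def try_with(obj):
    """Return the exception raised when `obj` is used in a with statement, or None."""
    try:
        with obj:
            pass
    except (AttributeError, TypeError) as exc:
        # AttributeError: __enter__ before Python 3.11, TypeError from 3.11 on.
        return exc
    return None


error = try_with(patch(PatchedModel()))
print(type(error).__name__, error)
```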

Dref360 commented 12 months ago

Ah, my bad: you need to use MCDropoutModule instead, as shown here:

with MCDropoutModule(model) as _mc_dropout_model:
    # The model is stochastic
    trainer.train()

arthur-thuy commented 12 months ago

Okay, thanks, it works now! I wasn't aware of any difference between patch_module and MCDropoutModule; I thought they could be used interchangeably.
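The difference can be illustrated with a stdlib-only analogue. The names ToyNet, patch and MCDropoutWrapper are hypothetical, and this is not Baal's actual implementation. A plain patching function hands back an ordinary object, whereas a wrapper class that defines __enter__ and __exit__ can switch the mode on entry and undo it on exit. That is presumably the behaviour relied on above when trainer.train() is wrapped in MCDropoutModule and trainer.evaluate() runs afterwards.

```python
class ToyNet:
    """Hypothetical model with a flag standing in for 'dropout active at inference'."""

    def __init__(self):
        self.mc_dropout = False


def patch(net):
    """Function-style patch: mutates and returns the net; it cannot be used in `with`."""
    net.mc_dropout = True
    return net


class MCDropoutWrapper:
    """Class-style wrapper: enables MC Dropout on entry, restores the old mode on exit."""

    def __init__(self, net):
        self.net = net
        self._previous = None

    def __enter__(self):
        self._previous = self.net.mc_dropout
        self.net.mc_dropout = True
        return self.net

    def __exit__(self, exc_type, exc, tb):
        self.net.mc_dropout = self._previous
        return False  # do not suppress exceptions raised inside the block


net = ToyNet()
with MCDropoutWrapper(net) as stochastic_net:
    inside = stochastic_net.mc_dropout  # True: "training" runs with MC Dropout
after = net.mc_dropout                  # False: evaluation runs deterministically
print(inside, after)

patched = patch(ToyNet())
print(patched.mc_dropout, hasattr(patched, "__enter__"))
```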