Closed: arthur-thuy closed this issue 12 months ago.
Sorry for the late answer.
It is quite easy to "temporarily unpatch the model", so I would go this route.
Something like:
model = ...
trainer = BaalTransformersTrainer(...)

with patch_module(model) as _mc_dropout_model:
    # The model is stochastic
    trainer.train()

# The model is deterministic
trainer.evaluate()
Let me know if that works for you, if not we can probably implement your suggestions.
Hi @Dref360,
I tried creating an MWE (see this gist), but I get an AttributeError: __enter__ error from the with statement. The gist is adapted from Baal's experiments/nlp_bert_mcdropout.py example.
Here is the full output of the script:
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
========== Without MC Dropout ==========
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:02<00:00, 6.38it/s]
[0.4908256880733945, 0.4908256880733945, 0.4908256880733945, 0.4908256880733945, 0.4908256880733945]
========== With MC Dropout ==========
Traceback (most recent call last):
File "/home/abthuy/Documents/PhD research/active-uncertainty/src/check_unpatch", line 131, in <module>
main()
File "/home/abthuy/Documents/PhD research/active-uncertainty/src/check_unpatch", line 105, in main
with patch_module(model) as _mc_dropout_model:
AttributeError: __enter__
Do you have an idea of how this could be fixed? I also tried running patch_module on trainer.model instead of on model, but that results in the same error.
Thank you
Ah, my bad. You would need to use MCDropoutModule, as shown here:

with MCDropoutModule(model) as _mc_dropout_model:
    # The model is stochastic
    trainer.train()
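For anyone hitting the same error: the difference is that MCDropoutModule implements the context-manager protocol (__enter__/__exit__), while patch_module is a plain function that modifies the model in place, so only the former works in a with statement. A minimal pure-Python sketch of this pattern (toy classes, not the real Baal API) might look like:

```python
class ToyModel:
    """Toy stand-in for a model: only tracks whether dropout is stochastic."""
    def __init__(self):
        self.mc_dropout_active = False


class MCDropoutContext:
    """Sketch of a context manager in the spirit of Baal's MCDropoutModule:
    dropout is stochastic inside the block and deterministic again outside."""
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        self.model.mc_dropout_active = True   # stochastic: trainer.train() goes here
        return self.model

    def __exit__(self, exc_type, exc, tb):
        self.model.mc_dropout_active = False  # deterministic: trainer.evaluate() goes here
        return False


model = ToyModel()
with MCDropoutContext(model) as mc_model:
    assert mc_model.mc_dropout_active       # MC Dropout active inside the block
assert not model.mc_dropout_active          # model restored after the block
```

A plain function like patch_module returns a patched model rather than a context manager, which is why the with statement raises AttributeError: __enter__.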
Okay thanks, it works now! I wasn't aware of any difference between patch_module and MCDropoutModule; I thought they could be used interchangeably.
Describe the bug
The HuggingFace + Baal example scripts (e.g. nlp_bert_mcdropout.py and this gist) use the BaalTransformersTrainer class. MC Dropout is activated and the model's performance is evaluated with the HF evaluate() function. The evaluate() function draws only 1 forward pass while MC Dropout is activated. I think this is incorrect, as you're essentially predicting with reduced model capacity. When drawing only 1 forward pass, MC Dropout should be deactivated. Alternatively, with MC Dropout activated, multiple forward passes can be drawn to obtain the predictive distribution.

To Reproduce
/

Expected behavior
It would be nice to have an iterations argument in the evaluate() function. When iterations == 1, the model is unpatched; when iterations > 1, multiple samples are drawn with predict_on_dataset like in the gist. I guess the BaalTransformersTrainer's evaluate() function needs to be overridden. I'm not sure whether the model could be temporarily unpatched. Metrics should also be handled the way the HF evaluate() function currently handles them.

Version (please complete the following information):

Additional context
/
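To illustrate the proposed iterations behaviour, here is a rough pure-Python sketch (the evaluate_mc helper and stochastic_forward stand-in are hypothetical; a real implementation would override BaalTransformersTrainer.evaluate and draw the samples with predict_on_dataset):

```python
import random


def stochastic_forward(x):
    """Stand-in for one MC Dropout forward pass: a noisy version of a toy score."""
    return x + random.gauss(0.0, 0.1)


def evaluate_mc(inputs, iterations=1):
    """Hypothetical sketch of the proposed behaviour.

    iterations == 1: the model would be unpatched, so predictions are
    deterministic. iterations > 1: the model stays patched and the
    predictive distribution is averaged over multiple stochastic passes.
    """
    if iterations == 1:
        # Unpatched model: one deterministic forward pass per input.
        return list(inputs)
    # Patched model: average `iterations` stochastic passes per input.
    return [
        sum(stochastic_forward(x) for _ in range(iterations)) / iterations
        for x in inputs
    ]


random.seed(0)
deterministic_preds = evaluate_mc([0.2, 0.8], iterations=1)
mc_preds = evaluate_mc([0.2, 0.8], iterations=50)
```

With enough iterations the averaged stochastic predictions concentrate around the deterministic ones, while their spread gives an uncertainty estimate, which is the point of keeping the model patched during evaluation.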
Do you think this is feasible?
Thank you