baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0

unpatch_module does not seem to be working #246

Closed · nitish1295 closed this 1 year ago

nitish1295 commented 1 year ago

Describe the bug

unpatch_module does not seem to be working as expected. Once dropout layers are enabled, I can't switch them off using unpatch_module.

To Reproduce

Run this in Colab

!pip install -qq baal transformers datasets

import torch
from transformers import BertForSequenceClassification
from baal.bayesian.dropout import patch_module, unpatch_module

pretrained_weights = 'bert-base-uncased'

use_cuda = torch.cuda.is_available()

model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
print(f"Droputs enabled: {model.dropout.training}") # False here

model = patch_module(model, inplace=False)
print(f"Dropouts enabled: {model.dropout.training}")  # True here

model = unpatch_module(model, inplace=False)
print(f"Dropouts enabled: {model.dropout.training}")  # Should be False here, but it is True??

# For vanilla NNs I used to do the following; not sure if this is relevant for BERT. Any ideas?
for m in model.modules():
    if m.__class__.__name__.startswith('Dropout'):
        m.eval()

print(f"Dropouts enabled: {model.dropout.training}")  # False here

Expected behavior

model = unpatch_module(model, inplace=False)
print(f"Dropouts enabled: {model.dropout.training}")  # <-- This should print False

Once unpatch_module is run, dropout layers should no longer be in training mode.

Version (please complete the following information):

Additional context

So, based on my understanding: while using MC Dropout with vanilla neural nets, I frequently checked model.dropout.training to confirm whether dropout was enabled. If the same holds for HF BERT models (which I think it does, since they are NN-based), then this is essentially a bug.
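As a side note, checking model.dropout only covers a single layer. A small helper (hypothetical, not part of baal) can report the train/eval state of every dropout-like submodule, using the same name-based check as the loop above so it also matches baal's patched Dropout class:

import torch.nn as nn

def dropout_states(model: nn.Module) -> dict:
    # Hypothetical helper: map each dropout-like submodule to its training
    # flag. Matching on the class name (as in the loop above) also catches
    # baal's patched Dropout class, not just plain nn.Dropout.
    return {
        name: m.training
        for name, m in model.named_modules()
        if m.__class__.__name__.startswith("Dropout")
    }

# e.g. print(dropout_states(model)) before and after unpatch_module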

Dref360 commented 1 year ago

By default, a Module has training=True. I think HuggingFace sets the model to eval automatically?
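For reference, that default behaviour in plain PyTorch (nothing baal-specific):

import torch.nn as nn

d = nn.Dropout()
print(d.training)  # True: freshly constructed modules start in train mode

d.eval()
print(d.training)  # False after an explicit eval()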

If we call eval on the resulting model we get what you expect (see below).

We should probably set training based on the status of the old Module. Good catch!
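A rough sketch of what that fix could look like (hypothetical, not baal's actual implementation): when swapping a dropout layer in or out, copy the old module's training flag onto its replacement instead of leaving the default True:

import torch.nn as nn

def swap_dropout_preserving_mode(parent: nn.Module) -> None:
    # Hypothetical sketch: replace each Dropout child and carry over its
    # train/eval state, so unpatching a model that was in eval mode does
    # not silently re-enable dropout.
    for name, child in parent.named_children():
        if isinstance(child, nn.Dropout):
            new_child = nn.Dropout(p=child.p, inplace=child.inplace)
            new_child.train(mode=child.training)  # preserve the old flag
            setattr(parent, name, new_child)
        else:
            swap_dropout_preserving_mode(child)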

import torch
from transformers import BertForSequenceClassification
from baal.bayesian.dropout import patch_module, unpatch_module

pretrained_weights = 'bert-base-uncased'

use_cuda = torch.cuda.is_available()

model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
print(f"Droputs enabled: {model.dropout.training}") # False here

model = patch_module(model, inplace=False)
print(f"Dropouts enabled: {model.dropout.training}")  # True here

model = unpatch_module(model, inplace=False).eval()
print(f"Dropouts enabled: {model.dropout.training}")  # False here, as expected