baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0

#246 Fix issue where training was not kept consistent #249

Closed · Dref360 closed 1 year ago

Dref360 commented 1 year ago

Summary:

Fixes #246

Features:

```python
import torch
from transformers import BertForSequenceClassification
from baal.bayesian.dropout import patch_module, unpatch_module

pretrained_weights = 'bert-base-uncased'
use_cuda = torch.cuda.is_available()

model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
print(f"Dropouts enabled: {model.dropout.training}, model= {model.training}")
# >>> Dropouts enabled: False, model= False

model = patch_module(model, inplace=False)
print(f"Dropouts enabled: {model.dropout.training}, model= {model.training}")
# >>> Dropouts enabled: True, model= False

model = unpatch_module(model, inplace=False)
print(f"Dropouts enabled: {model.dropout.training}, model= {model.training}")
# >>> Dropouts enabled: False, model= False  (before this fix, it stayed True; see #246)
```
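For context, here is a minimal plain-PyTorch sketch (no baal; the `MCDropout` class below is a hypothetical stand-in, not baal's implementation) of the behavior a patched module provides: dropout keeps firing even in eval mode, while the module's `training` flag itself stays consistent with the rest of the model.

```python
import torch
import torch.nn as nn

class MCDropout(nn.Dropout):
    """Toy MC-dropout layer: applies dropout regardless of train/eval mode."""
    def forward(self, x):
        # Force dropout on; self.training is deliberately ignored here.
        return nn.functional.dropout(x, self.p, training=True)

model = nn.Sequential(nn.Linear(64, 64), MCDropout(p=0.5), nn.Linear(64, 4))
model.eval()  # eval mode, yet predictions below remain stochastic

x = torch.ones(1, 64)
out1, out2 = model(x), model(x)
print(torch.allclose(out1, out2))  # almost surely False: masks differ per call
print(model.training)              # False: the flag itself stays consistent
```

The key invariant this PR enforces is the second print: patching and unpatching must leave `model.training` (and each submodule's flag) in the state the caller expects, rather than leaking `True` after `unpatch_module`.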

Checklist: