Closed — Dref360 closed this pull request 1 year ago
Fixes #246
"""Reproduction script for baal issue #246.

Checks that `patch_module` enables dropout at inference time on a
HuggingFace BERT classifier and that `unpatch_module` restores the
original behavior, observed via the `.training` flag of the model's
dropout layer.

NOTE(review): the pasted expected outputs below all read `False`, which
contradicts the inline narration ("True here") — the original paste is
internally inconsistent; the `>>>` lines are kept verbatim but should be
re-verified against a live run.
"""
import torch
from transformers import BertForSequenceClassification

from baal.bayesian.dropout import patch_module, unpatch_module

pretrained_weights = 'bert-base-uncased'
use_cuda = torch.cuda.is_available()  # computed but unused in this repro

# Downloads/loads the pretrained checkpoint (network I/O on first run).
model = BertForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=pretrained_weights
)

# Before patching: stock model, dropout layer follows eval-mode default.
print(f"Droputs enabled: {model.dropout.training}, model= {model.training}")
# False here
# >>> Droputs enabled: False, model= False

# After patching (inplace=False returns a patched copy): dropout is
# expected to stay active for MC-Dropout sampling.
model = patch_module(model, inplace=False)
print(f"Droputs enabled: {model.dropout.training}, model= {model.training}")
# True here  -- but the pasted output below shows False; see module docstring.
# >>> Droputs enabled: False, model= False

# After unpatching: dropout should be disabled again.
model = unpatch_module(model, inplace=False)
print(f"Droputs enabled: {model.dropout.training}, model= {model.training}")
# Should be False here but this is true??
# >>> Droputs enabled: False, model= False
tests/documentation_test.py
Summary:
Fixes #246
Features:
Checklist:
tests/documentation_test.py
).