baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0

Add caching mechanism for MCDropout #254

Closed: Dref360 closed this 1 year ago

Dref360 commented 1 year ago

Example with VGG16. It would be interesting to see the speedup on a segmentation model.

import torch
from torchvision.models import vgg16
from tqdm import tqdm

from baal.bayesian.caching_utils import MCCachingModule
from baal.bayesian.dropout import MCDropoutModule

vgg = vgg16().cuda()
vgg.eval()

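# Named `x` rather than `input` to avoid shadowing the Python builtin.
x = torch.randn(10, 3, 224, 224).cuda()

# Time for 1000 stochastic forward passes on the same batch:
# Regular: 1:49
# Cached : 20 seconds
with MCCachingModule(vgg) as model:
    with MCDropoutModule(model) as model_2:
        for _ in tqdm(range(1000)):
            model_2(x).detach().cpu()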
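
For reviewers who want the intuition behind the speedup, here is a toy, pure-Python sketch of the general idea. It is an assumption about how such caching can work, not the code in this PR. When the same input is sampled many times, everything computed before the first dropout layer is deterministic, so it can be computed once and reused. `LastInputCache`, `expensive_features`, `dropout`, and `mc_predict` are hypothetical names made up for this illustration.

```python
import random


class LastInputCache:
    """Hypothetical wrapper: remembers the output for the most recent input."""

    def __init__(self, fn):
        self.fn = fn
        self.calls = 0  # number of times the wrapped function actually ran
        self._last_input = None
        self._last_output = None

    def __call__(self, x):
        # Recompute only when the input differs from the previous call.
        if self._last_input is None or x != self._last_input:
            self._last_input = list(x)  # copy so later mutation is detected
            self._last_output = self.fn(x)
            self.calls += 1
        return self._last_output


def expensive_features(x):
    """Stand-in for the deterministic layers that precede the first dropout."""
    return [2.0 * v + 1.0 for v in x]


def dropout(features, p, rng):
    """Inverted dropout that stays active at inference time (MC Dropout)."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in features]


def mc_predict(features_fn, x, iterations, p=0.5, seed=0):
    """Draw `iterations` stochastic predictions for the same input."""
    rng = random.Random(seed)
    return [sum(dropout(features_fn(x), p, rng)) for _ in range(iterations)]


cached = LastInputCache(expensive_features)
samples = mc_predict(cached, [1.0, 2.0, 3.0], iterations=100)
print(f"feature computations: {cached.calls}, samples drawn: {len(samples)}")
```

In this sketch the deterministic part runs once for all 100 samples, while the dropout masks still differ between iterations. A cache that holds only the last input is enough here because MC sampling repeats the same batch back to back.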

TODO:

Dref360 commented 1 year ago

Need to add documentation, but ready for review.