Example with VGG16. It would be interesting to see the speedup on a segmentation model as well.
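import torch
from torchvision.models import vgg16
from tqdm import tqdm
from baal.bayesian.caching_utils import MCCachingModule
from baal.bayesian.dropout import MCDropoutModule

vgg = vgg16().cuda()
vgg.eval()
# One batch of 10 random 224x224 RGB images, reused on every MC iteration.
inputs = torch.randn(10, 3, 224, 224).cuda()

# Timings for the 1000 MC iterations below:
# Regular (without MCCachingModule): 1:49
# Cached (with MCCachingModule): 20 seconds
with MCCachingModule(vgg) as model:
    # Swap the dropout layers for MC dropout, which stays active in eval mode.
    with MCDropoutModule(model) as model_2:
        predictions = [model_2(inputs).detach().cpu() for _ in tqdm(range(1000))]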
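A plausible reading of why caching helps here: the same `inputs` batch is fed on every MC iteration, and in VGG16 the dropout layers sit only in the classifier, so everything before the first dropout (the conv features) is deterministic and gives the same result each time. The pure-Python sketch below illustrates that idea under this assumption. It is not Baal's implementation: `CountingExtractor`, `dropout_head` and `mc_predict` are hypothetical names standing in for the conv layers, the stochastic classifier and the MC loop.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

Features = Tuple[float, ...]


class CountingExtractor:
    """Hypothetical stand-in for VGG16's deterministic conv layers; counts its calls."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, x: Features) -> Features:
        self.calls += 1
        return tuple(2.0 * v for v in x)  # same input always gives the same output


def dropout_head(features: Features, p: float, rng: random.Random) -> float:
    """Stochastic head: inverted dropout on each feature, then a sum."""
    return sum(0.0 if rng.random() < p else f / (1.0 - p) for f in features)


def mc_predict(
    x: Features,
    extractor: Callable[[Features], Features],
    n_samples: int,
    p: float = 0.5,
    cache: Optional[Dict[Features, Features]] = None,
    seed: int = 0,
) -> List[float]:
    """Draw n_samples MC-dropout predictions for one input, optionally caching features."""
    rng = random.Random(seed)
    outputs = []
    for _ in range(n_samples):
        if cache is None:
            features = extractor(x)  # recomputed on every MC iteration
        else:
            if x not in cache:
                cache[x] = extractor(x)  # computed once on the first iteration
            features = cache[x]  # reused afterwards
        outputs.append(dropout_head(features, p, rng))
    return outputs


x = (1.0, 2.0, 3.0)
plain, cached = CountingExtractor(), CountingExtractor()
out_plain = mc_predict(x, plain, n_samples=1000)
out_cached = mc_predict(x, cached, n_samples=1000, cache={})
print(plain.calls, cached.calls)  # 1000 1
print(out_plain == out_cached)  # True
```

With the same seed, both runs draw the same dropout masks, so the predictions match while the deterministic part runs once instead of 1000 times. How much time this saves in practice depends on how much of the network sits before the first dropout layer, which is why the gain on a segmentation model may differ from VGG16's.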
TODO: