Model | Heads | Params (M) | Acc (%) |
---|---|---|---|
ResNet50 baseline (ref) | - | 23.5 | 93.62 |
BoTNet-50 | 1 | 18.8 | 95.11 |
BoTNet-50 | 4 | 18.8 | 95.78 |
BoTNet-S1-50 | 1 | 18.8 | 95.67 |
BoTNet-S1-59 | 1 | 27.5 | 95.98 |
BoTNet-S1-77 | 1 | 44.9 | WIP |
Model
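
    import torch
    from model import ResNet50  # the snippet builds ResNet50, so import that name

    model = ResNet50(num_classes=1000, resolution=(224, 224))
    x = torch.randn([2, 3, 224, 224])  # dummy batch: 2 RGB images at 224x224
    print(model(x).size())  # output shape: (batch, num_classes)
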
Module
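
    from model import MHSA

    resolution = 14  # spatial size (height = width) of the input feature map
    planes = 512  # channels of that feature map; an example value, match it to your network
    mhsa = MHSA(planes, width=resolution, height=resolution)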
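
To show roughly what a layer like `MHSA` computes, here is a minimal sketch of multi-head self-attention over a 2D feature map. It is not the `model.MHSA` implementation from this repository. The class name `MHSASketch`, the 1x1-convolution projections, the factorized learned row/column position embeddings (a simplification of the relative position logits described for BoTNet), and the 1/sqrt(d) logit scaling are all assumptions made for illustration. It keeps the argument order of the call above (channels, then `width` and `height`), and the input's spatial size must match those values because the position embeddings have a fixed size.

```python
import math

import torch
import torch.nn as nn


class MHSASketch(nn.Module):
    """Hypothetical 2D multi-head self-attention block (illustrative sketch only)."""

    def __init__(self, channels: int, width: int = 14, height: int = 14, heads: int = 4):
        super().__init__()
        if channels % heads != 0:
            raise ValueError("channels must be divisible by heads")
        self.heads = heads
        self.dim = channels // heads
        self.scale = 1.0 / math.sqrt(self.dim)
        # 1x1 convolutions project every spatial position to a query, key and value.
        self.query = nn.Conv2d(channels, channels, kernel_size=1)
        self.key = nn.Conv2d(channels, channels, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        # Assumed factorized position embeddings: one learned vector per row and per column.
        self.pos_h = nn.Parameter(torch.randn(1, heads, self.dim, height, 1))
        self.pos_w = nn.Parameter(torch.randn(1, heads, self.dim, 1, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        n = h * w
        # Split channels into heads and flatten space: (batch, heads, dim, positions).
        q = self.query(x).view(b, self.heads, self.dim, n)
        k = self.key(x).view(b, self.heads, self.dim, n)
        v = self.value(x).view(b, self.heads, self.dim, n)

        # Content-content logits q_i . k_j, shape (batch, heads, n, n).
        content = torch.matmul(q.transpose(2, 3), k)
        # Row + column embeddings broadcast to (1, heads, dim, h, w), then flattened.
        pos = (self.pos_h + self.pos_w).view(1, self.heads, self.dim, n)
        # Content-position logits q_i . r_j, also (batch, heads, n, n).
        position = torch.matmul(q.transpose(2, 3), pos)

        # Softmax over key positions j for each query position i.
        attn = torch.softmax((content + position) * self.scale, dim=-1)
        # out[..., i] = sum_j attn[i, j] * v[..., j], shape (batch, heads, dim, n).
        out = torch.matmul(v, attn.transpose(2, 3))
        return out.reshape(b, c, h, w)
```

For example, `MHSASketch(64, width=7, height=5)(torch.randn(2, 64, 5, 7))` returns a tensor of shape `(2, 64, 5, 7)`. Because input and output shapes match, a block like this can stand in for the 3x3 convolution inside a ResNet bottleneck, which is the substitution BoTNet makes in its final stage.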