asy51 opened 9 months ago
I think the models were obtained with `inception_v3(aux_logits=False)` (i.e. without the auxiliary classifiers).
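Here is why `aux_logits` matters for a loader that wraps the children in `nn.Sequential`. The `Sequential` names its submodules `0`, `1`, `2` and so on by position, and in torchvision's `Inception3` the `AuxLogits` head is registered between `Mixed_6e` and `Mixed_7a`. With `aux_logits=True` it becomes an extra child, so every later block gets a higher index and the checkpoint's `backbone.N.*` keys no longer line up. The pure-Python sketch below illustrates this with a shortened, hypothetical list of child names (not torchvision's full list). It does not need torch installed.

```python
# Sketch: how nn.Sequential(*model.children()[:-1]) would index each child.
# CHILDREN is a shortened, hypothetical list, not torchvision's full module list.
CHILDREN = ["Conv2d_1a_3x3", "Mixed_6e", "AuxLogits", "Mixed_7a", "Mixed_7c", "fc"]


def sequential_indices(children, aux_logits):
    """Map child name -> positional index inside the Sequential backbone."""
    # With aux_logits=False there is no AuxLogits child at all.
    kept = [name for name in children if aux_logits or name != "AuxLogits"]
    # Drop the final classifier, mirroring encoder_layers[:-1] in the loader code.
    return {name: idx for idx, name in enumerate(kept[:-1])}


with_aux = sequential_indices(CHILDREN, aux_logits=True)
without_aux = sequential_indices(CHILDREN, aux_logits=False)

for name in ("Mixed_6e", "Mixed_7a", "Mixed_7c"):
    print(f"{name}: backbone.{with_aux[name]}.* vs backbone.{without_aux[name]}.*")
```

Blocks registered before `AuxLogits` keep their index; everything after it shifts by one, which is why a backbone built with the wrong `aux_logits` setting reports mismatched or missing keys.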
Using torch v2.1.1 and torchvision v0.16.1, running:
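```python
import torch
import torch.nn as nn
from torchvision.models import inception_v3


class Backbone(nn.Module):
    def __init__(self, path):
        super().__init__()
        # weights=None replaces the deprecated pretrained=False (torchvision >= 0.13);
        # aux_logits=False means no AuxLogits child is created.
        base_model = inception_v3(weights=None, aux_logits=False)
        encoder_layers = list(base_model.children())
        # Drop the final fc layer; nn.Sequential names the remaining children 0, 1, 2, ...
        self.backbone = nn.Sequential(*encoder_layers[:-1])
        # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
        state_dict = torch.load(path, map_location="cpu")
        # Strip the 9-character key prefix (presumably "backbone.") from every key.
        new_state_dict = {k[9:]: v for k, v in state_dict.items()}
        print(self.backbone.load_state_dict(new_state_dict))  # <All keys matched successfully>

    def forward(self, x):
        return self.backbone(x)


backbone = Backbone("path/to/InceptionV3.pt")
```

Works for me.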
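The `k[9:]` slice in the loader assumes every checkpoint key starts with a 9-character prefix, presumably `backbone.`. A slightly safer variant strips the prefix only where it is present, so keys that lack it show up as unexpected keys instead of being silently truncated. A minimal sketch, assuming the prefix is `backbone.`; `strip_prefix` is a hypothetical helper name and the toy checkpoint uses strings in place of tensors:

```python
PREFIX = "backbone."  # assumption: inferred from the 9-character k[9:] slice


def strip_prefix(state_dict, prefix=PREFIX):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged, so load_state_dict reports them
    as unexpected keys instead of loading them under a truncated name.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Toy checkpoint: string values stand in for tensors.
checkpoint = {"backbone.0.conv.weight": "w0", "backbone.0.bn.bias": "b0", "fc.weight": "wf"}
print(strip_prefix(checkpoint))
# {'0.conv.weight': 'w0', '0.bn.bias': 'b0', 'fc.weight': 'wf'}
print({k[9:]: v for k, v in checkpoint.items()})
# {'0.conv.weight': 'w0', '0.bn.bias': 'b0', '': 'wf'}
```

In the loader it would replace the dict comprehension: `new_state_dict = strip_prefix(torch.load(path, map_location="cpu"))`. The toy `fc.weight` key is exactly 9 characters long, which is why the plain slice turns it into an empty string.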
I am getting `state_dict` mismatches:

Error msg:

torch and torchvision versions: `('2.0.0+cu117', '0.15.1+cu117')`

It doesn't seem like a conda `environment.yaml` or pip `requirements.txt` file is available. Please advise on how to load the weights! 🙏
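When keys don't match, calling `load_state_dict(new_state_dict, strict=False)` returns the `missing_keys` and `unexpected_keys` instead of raising on them (shape mismatches still raise). The same comparison can be done on plain key lists to check for the usual culprits in this thread: an unstripped prefix, or an `aux_logits` setting that differs from the one used when saving. A minimal pure-Python sketch; `diagnose_keys` is hypothetical and not part of PyTorch:

```python
def diagnose_keys(checkpoint_keys, model_keys, prefix="backbone."):
    """Compare key sets the way load_state_dict's strict check does.

    Returns (missing, unexpected, prefix_left_over):
      missing          -- keys the model expects but the checkpoint lacks
      unexpected       -- keys in the checkpoint that the model has no slot for
      prefix_left_over -- True if every unexpected key still starts with `prefix`
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    missing = sorted(model - ckpt)
    unexpected = sorted(ckpt - model)
    prefix_left_over = bool(unexpected) and all(k.startswith(prefix) for k in unexpected)
    return missing, unexpected, prefix_left_over


# Toy example: the checkpoint keys were never stripped of "backbone.".
missing, unexpected, prefix_left_over = diagnose_keys(
    ["backbone.0.conv.weight", "backbone.0.bn.weight"],
    ["0.conv.weight", "0.bn.weight"],
)
print(missing)           # ['0.bn.weight', '0.conv.weight']
print(unexpected)        # ['backbone.0.bn.weight', 'backbone.0.conv.weight']
print(prefix_left_over)  # True
```

With real objects, pass the keys of the dict returned by `torch.load(...)` and `model.state_dict().keys()` from the backbone you built.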