from torchvision import datasets
from vendi_score import image_utils
mnist = datasets.MNIST("data/mnist", train=False, download=True)
digits = [[x for x, y in mnist if y == c] for c in range(10)]
pixel_vs = [image_utils.pixel_vendi_score(imgs) for imgs in digits]
# The default embeddings are from the pool-2048 layer of the torchvision
# Inception v3 model.
inception_vs = [image_utils.embedding_vendi_score(imgs, device="cuda") for imgs in digits]
for y, (pvs, ivs) in enumerate(zip(pixel_vs, inception_vs)):
    print(f"{y}\t{pvs:.02f}\t{ivs:.02f}")
# Output:
# 0 7.68 3.45
# 1 5.31 3.50
# 2 12.18 3.62
# 3 9.97 2.97
# 4 11.10 3.75
# 5 13.51 3.16
# 6 9.06 3.63
# 7 9.58 4.07
# 8 9.69 3.74
# 9 8.56 3.43
Running the snippet above resulted in a bug: get_inception loads a pretrained model but also asks torchvision to initialize the weights, which doesn't make sense for a pretrained model.
Fixed by setting init_weights=False in the get_inception function.
BTW: also fixed the warning about using the older PyTorch parameter name, "pretrained", instead of "weights".
So, the new function looks like this:
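from torch import nn
from torchvision.models import Inception_V3_Weights, inception_v3

def get_inception(pretrained=True, pool=True):
    if pretrained:
        # Pretrained weights are loaded, so random initialization must be off.
        model = inception_v3(
            weights=Inception_V3_Weights.DEFAULT, transform_input=True, init_weights=False
        ).eval()
    else:
        # No pretrained weights: initialize the model randomly.
        model = inception_v3(
            transform_input=True, init_weights=True
        ).eval()
    if pool:
        # Drop the classifier so the model returns the pooled features.
        model.fc = nn.Identity()
    return model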
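To illustrate why the old call failed, here is a minimal pure-Python sketch of the kind of keyword check that, as far as I understand it, rejects init_weights=True when pretrained weights are requested. It is a simplified stand-in, not torchvision's actual code: force_param and build_kwargs are hypothetical names made up for this example, and the error message is only an approximation.

```python
# Simplified stand-in for a keyword check that rejects init_weights=True
# alongside pretrained weights. All names here are hypothetical; this is
# not torchvision's actual code.

def force_param(kwargs, name, required):
    """Set kwargs[name] to `required`, refusing a conflicting explicit value."""
    if name in kwargs and kwargs[name] != required:
        raise ValueError(
            f"The parameter '{name}' expected value {required} "
            f"but got {kwargs[name]} instead."
        )
    kwargs[name] = required


def build_kwargs(weights=None, **kwargs):
    """Return the keyword arguments a model constructor would receive."""
    if weights is not None:
        # Pretrained weights are about to be loaded, so random
        # initialization would be pointless: require it to be off.
        force_param(kwargs, "init_weights", False)
    return kwargs


# Old call pattern: pretrained weights plus init_weights=True is rejected.
try:
    build_kwargs(weights="DEFAULT", init_weights=True)
    conflict = None
except ValueError as err:
    conflict = str(err)

# Fixed call pattern: pretrained weights plus init_weights=False is accepted.
fixed = build_kwargs(weights="DEFAULT", transform_input=True, init_weights=False)

# Training from scratch: no weights, so init_weights=True is allowed.
scratch = build_kwargs(transform_input=True, init_weights=True)
```

Under these assumptions, only the combination of pretrained weights and init_weights=True is refused, which matches the two branches of get_inception above.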