vertaix / Vendi-Score

MIT License

Bug when running code example #6

Closed: EyalMichaeli closed this issue 1 year ago

EyalMichaeli commented 1 year ago

I tried to run this code example:

from torchvision import datasets
from vendi_score import image_utils

mnist = datasets.MNIST("data/mnist", train=False, download=True)
digits = [[x for x, y in mnist if y == c] for c in range(10)]
pixel_vs = [image_utils.pixel_vendi_score(imgs) for imgs in digits]
# The default embeddings are from the pool-2048 layer of the torchvision
# Inception v3 model.
inception_vs = [image_utils.embedding_vendi_score(imgs, device="cuda") for imgs in digits]
for y, (pvs, ivs) in enumerate(zip(pixel_vs, inception_vs)):
    print(f"{y}\t{pvs:.02f}\t{ivs:.02f}")

# Output:
# 0       7.68    3.45
# 1       5.31    3.50
# 2       12.18   3.62
# 3       9.97    2.97
# 4       11.10   3.75
# 5       13.51   3.16
# 6       9.06    3.63
# 7       9.58    4.07
# 8       9.69    3.74
# 9       8.56    3.43
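A note on the print statement: the format spec 02f (without the dot) sets a minimum width of 2 with zero padding but keeps Python's default precision of six decimals, whereas .02f rounds to two decimal places, which is what the output above shows. A quick check using one of the Inception scores from that output:

```python
# Compare the two format specs on a value taken from the output above.
value = 3.45

# "02f": zero-padded minimum width of 2, default precision of 6 decimals.
wide = f"{value:02f}"

# ".02f": precision of 2 decimals (the leading 0 in "02" does not change anything).
rounded = f"{value:.02f}"

print(wide, rounded)
```

Only the second spec produces the two-decimal columns shown in the expected output.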

This resulted in an error: Inception v3 is loaded with pretrained weights, yet get_inception also asks torchvision to initialize the weights (init_weights=True), which doesn't make sense for a pretrained model.

I fixed it by setting init_weights=False in the get_inception function.
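To make the conflict concrete, here is a small pure-Python sketch of the kind of argument check involved. The function build_inception_kwargs is hypothetical and is not part of torchvision or vendi_score. The assumption, worth checking against your installed torchvision version, is that recent releases of inception_v3 force init_weights to False whenever weights are supplied and reject a conflicting explicit value, since the loaded state dict would overwrite any random initialization anyway.

```python
from typing import Any, Dict, Optional


def build_inception_kwargs(weights: Optional[str], **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical sketch: return model kwargs, forcing init_weights=False when weights are loaded.

    This mimics the behavior assumed above; it is not torchvision's actual code.
    """
    if weights is not None:
        # Pretrained weights replace every parameter, so random init is pointless.
        # An explicit, conflicting request is treated as an error.
        if "init_weights" in kwargs and kwargs["init_weights"] is not False:
            raise ValueError(
                f"init_weights must be False when weights are given, got {kwargs['init_weights']!r}"
            )
        kwargs["init_weights"] = False
    return kwargs


# Mirrors the fixed call: pretrained weights together with init_weights=False.
ok = build_inception_kwargs("DEFAULT", transform_input=True, init_weights=False)
print(ok)
```

Under this sketch, the original combination (weights plus init_weights=True) raises, while the fixed combination passes through unchanged.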

BTW, I also fixed the deprecation warning caused by the old torchvision parameter name "pretrained", which has been replaced by "weights". The new function looks like this:

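import torch.nn as nn
from torchvision.models import Inception_V3_Weights, inception_v3


def get_inception(pretrained=True, pool=True):
    if pretrained:
        # Load ImageNet weights; init_weights must be False because the
        # pretrained state dict replaces any random initialization.
        model = inception_v3(
            weights=Inception_V3_Weights.DEFAULT, transform_input=True, init_weights=False
        ).eval()
    else:
        # No weights: build the architecture with randomly initialized parameters.
        model = inception_v3(
            weights=None, transform_input=True, init_weights=True
        ).eval()
    if pool:
        # Drop the classifier so the model returns the 2048-d pool features.
        model.fc = nn.Identity()
    return model
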
adjidieng commented 1 year ago

Fixed with the pull request merge.