facebookresearch / VLPart

[ICCV2023] VLPart: Going Denser with Open-Vocabulary Part Segmentation
MIT License

How to save new zero-shot weights? #8

Closed · joshmyersdean closed this issue 9 months ago

joshmyersdean commented 9 months ago

Hello!

Which weights are you saving as the zero-shot weights?

joshmyersdean commented 9 months ago

Ah, found the line here: https://github.com/facebookresearch/VLPart/blob/main/demo/predictor.py#L23

joshmyersdean commented 9 months ago

Hi @PeizeSun,

For the zero-shot weights, do you just run each category name in the part vocabulary through CLIP's text encoder and then save the resulting matrix?
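The procedure asked about above (one prompt per category, encode, stack the rows into a matrix) can be sketched as follows. This is a minimal illustration, not VLPart's actual dumping script: `build_zero_shot_weights` and its `encode_text` callback are hypothetical names, and the default template `"a {}"` is only a guess inferred from the `a+cname` suffix of the saved filename. The float16 output dtype mirrors the `dtype=float16` of the saved array printed later in this thread.

```python
from typing import Callable, Sequence

import numpy as np


def build_zero_shot_weights(
    names: Sequence[str],
    encode_text: Callable[[str], np.ndarray],
    template: str = "a {}",
    dtype: type = np.float16,
) -> np.ndarray:
    """Stack one text embedding per category into a (num_categories, dim) matrix.

    `encode_text` is a hypothetical stand-in for a CLIP text encoder that maps
    one prompt string to a 1-D embedding. `template` defaults to "a {}", an
    assumption read off the "a+cname" suffix of the saved .npy filename.
    """
    rows = []
    for name in names:
        # "object:part" names become "object part", as in the thread's `form` lambda
        prompt = template.format(name.replace(":", " "))
        rows.append(np.asarray(encode_text(prompt), dtype=np.float32).reshape(-1))
    # Row i of the result is the embedding of names[i]
    return np.stack(rows, axis=0).astype(dtype)
```

With OpenAI's `clip` package, `encode_text` could wrap something like `model.encode_text(clip.tokenize([prompt]).to(device))` under `torch.no_grad()`, followed by `.float().cpu().numpy()`. Passing `template="A photo of a {}"` reproduces the prompt used in the snippet below. Keeping the encoder as a parameter makes it easy to swap prompt templates or CLIP backbones when trying to match an existing file.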

joshmyersdean commented 9 months ago

Update: when I do the above, I get different weights from the saved ones.

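import numpy as np
import torch

# Prompt template: a "object:part" name becomes "A photo of a object part"
form = lambda x: f"A photo of a {' '.join(x.split(':'))}"

# One row per category; CLIP RN50 text embeddings are 1024-dimensional
embeds = torch.zeros(len(PASCAL_PART_BASE_CATEGORIES), 1024)
with torch.no_grad():
    for idx, i in enumerate(PASCAL_PART_BASE_CATEGORIES):
        embeds[idx, :] = text_encoder(text_encoder.tokenize(form(i['name'])))

saved_embeds = np.load("datasets/metadata/pascal_part_base_clip_RN50_a+cname.npy")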

print(embeds)
tensor([[-0.0448,  0.0072, -0.1996,  ...,  0.0097, -0.2787,  0.3235],
        [-0.0216,  0.2441, -0.2372,  ..., -0.1914, -0.1295,  0.1115],
        [ 0.0405, -0.1673, -0.1807,  ..., -0.2903, -0.2093,  0.0923],
        ...,
        [ 0.0482, -0.1842, -0.0162,  ..., -0.3602,  0.1764,  0.0118],
        [ 0.0588, -0.1146, -0.0537,  ..., -0.1661, -0.0903,  0.0582],
        [ 0.4419,  0.0859,  0.1314,  ..., -0.0912, -0.3575, -0.1130]])

print(saved_embeds)
array([[-0.2603  ,  0.1097  , -0.2805  , ..., -0.03482 , -0.1514  ,
         0.3855  ],
       [-0.1748  ,  0.328   , -0.3494  , ..., -0.3804  , -0.0924  ,
         0.141   ],
       [-0.09174 , -0.04358 , -0.2373  , ..., -0.3523  , -0.1426  ,
         0.1539  ],
       ...,
       [-0.01486 , -0.1587  , -0.11475 , ..., -0.4219  ,  0.1277  ,
         0.1537  ],
       [-0.0338  , -0.004128, -0.2137  , ..., -0.3127  , -0.04684 ,
         0.1978  ],
       [ 0.2861  ,  0.1469  , -0.0263  , ..., -0.2102  , -0.3286  ,
         0.03644 ]], dtype=float16)
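Eyeballing the first few values does not say *how* two embedding matrices differ. A small diagnostic like the one below can help separate a scale difference (for example, one side L2-normalized and the other not) from a genuinely different prompt or model. It is a sketch: `compare_embeddings` is a hypothetical helper, and both inputs are cast to float32 because the saved array is float16.

```python
import numpy as np


def compare_embeddings(ours: np.ndarray, saved: np.ndarray) -> dict:
    """Compare two (num_categories, dim) embedding matrices row by row."""
    a = np.asarray(ours, dtype=np.float32)
    b = np.asarray(saved, dtype=np.float32)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    a_norms = np.linalg.norm(a, axis=1)
    b_norms = np.linalg.norm(b, axis=1)
    # Cosine similarity ignores per-row scale: values near 1 with different
    # norms suggest a normalization/scale difference, while low values suggest
    # a different prompt template, category order, or CLIP model.
    cos = np.sum(a * b, axis=1) / (a_norms * b_norms)
    return {
        "cosine_per_row": cos,
        "min_cosine": float(cos.min()),
        "ours_row_norms": a_norms,
        "saved_row_norms": b_norms,
        "max_abs_diff": float(np.max(np.abs(a - b))),
    }
```

For the tensors above, this could be called as `compare_embeddings(embeds.numpy(), saved_embeds)`. A low `min_cosine` would point toward a mismatch in the prompt (for example, "A photo of a ..." versus a shorter template) or in the category order, rather than toward a float16 rounding effect.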