monajalal closed this issue 3 years ago.
I changed line 118 of type_specific_network.py
from:
norm = torch.norm(masked_embedding, p=2, dim=2) + 1e-10
to:
norm = torch.norm(masked_embedding, p=2, dim=2, keepdim=True) + 1e-10
This is what the result of running it looks like:
[jalal@goku fashion-compatibility]$ source /scratch3/venv/fashcomp/bin/activate
(fashcomp) [jalal@goku fashion-compatibility]$ python main.py --test --l2_embed --resume runs/nondisjoint_l2norm/model_best.pth.tar --datadir ../../../data/fashion/
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
=> loading checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar'
=> loaded checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar' (epoch 5)
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
test set: Compat AUC: 0.88 FITB: 57.6
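To see what keepdim=True changes, here is a minimal sketch of the shapes involved. It assumes masked_embedding has shape (batch, n_masks, dim) and that the norm is then used to divide masked_embedding element-wise. The sizes below are hypothetical and not taken from the repository.

```python
import torch

# Hypothetical sizes (not from the repo): batch of 2, 4 masks, 64-dim embedding.
batch, n_masks, dim = 2, 4, 64
masked_embedding = torch.randn(batch, n_masks, dim)

# Without keepdim, the reduced dimension is dropped: shape (2, 4).
norm_dropped = torch.norm(masked_embedding, p=2, dim=2) + 1e-10

# With keepdim=True, the reduced dimension is kept with size 1: shape (2, 4, 1).
norm_kept = torch.norm(masked_embedding, p=2, dim=2, keepdim=True) + 1e-10

# (2, 4, 64) / (2, 4, 1) broadcasts, so each embedding is scaled to unit L2 norm.
normalized = masked_embedding / norm_kept

# (2, 4, 64) / (2, 4) does not broadcast: trailing sizes 64 and 4 differ.
try:
    masked_embedding / norm_dropped
    dropped_division_ok = True
except RuntimeError:
    dropped_division_ok = False

print(norm_dropped.shape)   # torch.Size([2, 4])
print(norm_kept.shape)      # torch.Size([2, 4, 1])
print(normalized.shape)     # torch.Size([2, 4, 64])
print(dropped_division_ok)  # False
```

In this sketch, keepdim=True only adds a size-1 axis to the norm itself; the normalized tensor keeps the same shape as masked_embedding.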
@monajalal I don't think we should use keepdim=True. It changes the shape of the embedding itself, e.g., from 66 to 67.
Regarding https://github.com/mvasil/fashion-compatibility/pull/13
If you run the code with the current stable version of PyTorch (1.9), you get: