zihangJiang / TokenLabeling

Pytorch implementation of "All Tokens Matter: Token Labeling for Training Better Vision Transformers"
Apache License 2.0
426 stars 36 forks

How to print the outputs of wrong predictions on the validation dataset? #19

Closed Williamlizl closed 3 years ago

zihangJiang commented 3 years ago

You may refer to the code here to compare the output (prediction) and the target (ground truth). https://github.com/zihangJiang/TokenLabeling/blob/09bb641b1e8f3e94fa1b6c7180addf4507458541/validate.py#L238-L242
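
A minimal sketch of how that comparison could be extended to print the misclassified samples. This is not the repo's exact code; it assumes the `output` (logits) and `target` tensors used in that validation loop, with `batch_idx` as the loop counter, and the print format is illustrative:

```python
# Hypothetical addition inside the validation loop, right after the forward pass.
# output: [batch_size, num_classes] logits, target: [batch_size] ground-truth labels.
pred = output.argmax(dim=1)                           # top-1 predicted class per sample
wrong = (pred != target).nonzero(as_tuple=True)[0]    # positions that were misclassified
for i in wrong.tolist():
    print(f'batch {batch_idx}, sample {i}: pred={pred[i].item()}, target={target[i].item()}')
```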

Williamlizl commented 3 years ago

You may refer to the code here to compare the output (prediction) and the target (ground truth).

https://github.com/zihangJiang/TokenLabeling/blob/09bb641b1e8f3e94fa1b6c7180addf4507458541/validate.py#L238-L242

And if I want to get the image's path (directory) along with the prediction, how can I do that?

zihangJiang commented 3 years ago

To get the path of the images, you may refer to https://github.com/zihangJiang/TokenLabeling/blob/09bb641b1e8f3e94fa1b6c7180addf4507458541/generate_label.py#L110-L128
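
A minimal sketch of mapping a misclassified sample back to its file on disk, assuming a torchvision `ImageFolder`-style dataset (not necessarily the dataset class used in the repo); the root path and index are placeholders:

```python
from torchvision.datasets import ImageFolder

# Hypothetical setup: dataset.samples holds (file_path, class_index) pairs in loading
# order, so with an unshuffled DataLoader the global index
# (batch_idx * batch_size + position_in_batch) points back to the image file.
dataset = ImageFolder('/path/to/validation')          # placeholder validation root
global_idx = 0                                        # illustrative index of a wrong prediction
path, label = dataset.samples[global_idx]
print(f'misclassified image: {path} (ground-truth class {label})')
```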

Williamlizl commented 3 years ago

To get the path of the images, you may refer to

https://github.com/zihangJiang/TokenLabeling/blob/09bb641b1e8f3e94fa1b6c7180addf4507458541/generate_label.py#L110-L128

Is there no test.py for inference?

zihangJiang commented 3 years ago

To get the path of the images, you may refer to https://github.com/zihangJiang/TokenLabeling/blob/09bb641b1e8f3e94fa1b6c7180addf4507458541/generate_label.py#L110-L128

Is there no test.py for inference?

You can use this Colab notebook for inference. It uses the VOLO model, but you can simply switch to ours via `from tlt.models import lvvit_s` and download the pre-trained model here.
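
For reference, a minimal end-to-end sketch of that flow; the checkpoint and image paths are placeholders, and `load_pretrained_weights` / `create_transform` are used the same way as later in this thread:

```python
import torch
from PIL import Image
from timm.data import create_transform
from tlt.models import lvvit_s
from tlt.utils import load_pretrained_weights

model = lvvit_s(img_size=384)
load_pretrained_weights(model=model, checkpoint_path='lvvit_s-384.pth.tar')  # placeholder path
model.eval()

# Preprocess a single image with the model's default crop settings
transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])
image = Image.open('example.jpg').convert('RGB')      # placeholder image
input_image = transform(image).unsqueeze(0)           # add batch dimension

# Forward pass and top-1 prediction
with torch.no_grad():
    logits = model(input_image)
print('predicted class index:', logits.argmax(dim=1).item())
```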

Williamlizl commented 3 years ago

To get the path of the images, you may refer to https://github.com/zihangJiang/TokenLabeling/blob/09bb641b1e8f3e94fa1b6c7180addf4507458541/generate_label.py#L110-L128

Is there no test.py for inference?

You can use this Colab notebook for inference. It uses the VOLO model, but you can simply switch to ours via `from tlt.models import lvvit_s` and download the pre-trained model here.

```python
from tlt.models import lvvit_s
from PIL import Image
from tlt.utils import load_pretrained_weights
from timm.data import create_transform

model = lvvit_s(img_size=384)
load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar')
model.eval()
transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])
image = Image.open('/home/lbc/GitHub/c/train/LV-ViT/validation/1_val/323_l2.jpg')
input_image = transform(image).unsqueeze(0)
```

```
RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>
      4 from timm.data import create_transform
      5 model = lvvit_s(img_size=384)
----> 6 load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar')
      7 model.eval()
      8 transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])

~/.local/lib/python3.7/site-packages/tlt/utils/utils.py in load_pretrained_weights(model, checkpoint_path, use_ema, strict, num_classes)
    109 def load_pretrained_weights(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
    110     state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes)
--> 111     model.load_state_dict(state_dict, strict=strict)
    112
    113

~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1222         if len(error_msgs) > 0:
   1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224                 self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1226

RuntimeError: Error(s) in loading state_dict for LV_ViT:
    Missing key(s) in state_dict: "head.weight", "head.bias", "aux_head.weight", "aux_head.bias".
```

zihangJiang commented 3 years ago

Please use the latest version of our repo (pip install tlt==0.2.0). This was a bug in the function in tlt/utils.py in an earlier version, which deleted all classification heads in order to support transfer learning.

Williamlizl commented 3 years ago

Please use the latest version of our repo (pip install tlt==0.2.0). This was a bug in the function in tlt/utils.py in an earlier version, which deleted all classification heads in order to support transfer learning.

```python
from tlt.models import lvvit_s
from PIL import Image
from tlt.utils import load_pretrained_weights
from timm.data import create_transform

model = lvvit_s(img_size=384)
load_pretrained_weights(model=model,
                        checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar',
                        strict=False, num_classes=2)
model.eval()
print(model)
transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])
image = Image.open('/home/lbc/GitHub/c/train/LV-ViT/validation/1_val/323_l2.jpg')
input_image = transform(image).unsqueeze(0)
```

If I use `model = lvvit_s(img_size=384)`, it loads the official model, but how do I load my fine-tuned model?

zihangJiang commented 3 years ago

If the number of classes is not 1000, you should also pass `num_classes` when building the model (i.e. `model = lvvit_s(img_size=384, num_classes=2)`).
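
Putting the two points together, a minimal sketch of loading a fine-tuned 2-class checkpoint; the checkpoint path is a placeholder and the call mirrors the one earlier in this thread:

```python
from tlt.models import lvvit_s
from tlt.utils import load_pretrained_weights

# Build the model with the same head size the checkpoint was fine-tuned with
model = lvvit_s(img_size=384, num_classes=2)
load_pretrained_weights(model=model,
                        checkpoint_path='model_best.pth.tar',  # placeholder: your fine-tuned checkpoint
                        strict=False, num_classes=2)
model.eval()
```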

Williamlizl commented 3 years ago

If the number of classes is not 1000, you should also pass `num_classes` when building the model (i.e. `model = lvvit_s(img_size=384, num_classes=2)`).

It does work, thank you