kdexd / virtex

[CVPR 2021] VirTex: Learning Visual Representations from Textual Annotations
http://kdexd.xyz/virtex
MIT License

eval_captioning.py - RuntimeError: gather_out_cuda(): Expected dtype int64 for index #20

Closed by yonatanbitton 3 years ago

yonatanbitton commented 3 years ago

Hello,

I'm trying to run inference on new images.

I've followed the setup instructions.

This is the command I'm running:

python scripts/eval_captioning.py \
    --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
    --checkpoint-path /tmp/.torch/virtex_cache/bicaptioning_R_50_L1_H2048.pth \
    --data-root /path/to/images_dir \
    --output /path/to/save/predictions.json \
    --num-gpus-per-machine 1 \
    --cpu-workers 4

This is my stack trace:

** fvcore version of PathManager will be deprecated soon. **
** Please migrate to the version in iopath repo. **
https://github.com/facebookresearch/iopath 

2020-12-16 16:34:21.685 | INFO     | virtex.utils.checkpointing:load:156 - Rank 0: Loading checkpoint from /tmp/.torch/virtex_cache/bicaptioning_R_50_L1_H2048.pth
2020-12-16 16:34:22.134 | INFO     | virtex.utils.checkpointing:load:166 - Rank 0: Loading model from /tmp/.torch/virtex_cache/bicaptioning_R_50_L1_H2048.pth
Traceback (most recent call last):
  File "scripts/eval_captioning.py", line 114, in <module>
    main(_A)
  File "scripts/eval_captioning.py", line 78, in main
    output_dict = model(val_batch)
  File "/yonatan/virtex/virtex_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/yonatan/virtex/virtex/models/captioning.py", line 176, in forward
    start_predictions, beam_search_step
  File "/yonatan/virtex/virtex/utils/beam_search.py", line 250, in search
    predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
RuntimeError: gather_out_cuda(): Expected dtype int64 for index
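The failing call can be reproduced in isolation. `Tensor.gather` requires its `index` argument to be an int64 (`torch.long`) tensor, so passing float backpointers raises the same kind of dtype error (the exact message prefix differs between CPU and CUDA and across PyTorch versions). The tensor names and values below are illustrative and are not taken from `beam_search.py`.

```python
import torch

# Hypothetical token predictions for one timestep: (batch_size=2, beam_size=3).
predictions_t = torch.tensor([[10, 11, 12],
                              [20, 21, 22]])

# Hypothetical backpointers that ended up as float32 (e.g. after a true division).
float_backpointers = torch.tensor([[0.0, 2.0, 1.0],
                                   [1.0, 1.0, 0.0]])

try:
    predictions_t.gather(1, float_backpointers)
    gather_failed = False
except RuntimeError as err:
    # Message is along the lines of "...Expected dtype int64 for index".
    gather_failed = True
    print(err)

# Casting the index to int64 makes the same call succeed.
reordered = predictions_t.gather(1, float_backpointers.long())
print(reordered)
```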

This is my system (Linux):

(virtex_env) (base) [p virtex]$ python
Python 3.6.10 |Anaconda, Inc.| (default, Jan  7 2020, 21:14:29) 
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'1.7.0'
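With PyTorch 1.7.0 installed, one quick diagnostic is to check how the installed version divides integer tensors. Since PyTorch 1.7, `/` on integer tensors performs true division and returns a floating-point tensor, whereas older releases such as 1.4 truncated and kept an integer dtype. That this change is what breaks `beam_search.py` is an assumption; the thread only says the error comes from the PyTorch upgrade. The variable names below are hypothetical.

```python
import torch

print("PyTorch", torch.__version__)

# Hypothetical flat indices, int64 like the output of topk().
flat_indices = torch.tensor([7, 12, 3])
vocab_size = 5

# True division: returns a float tensor on PyTorch >= 1.7.
true_div = flat_indices / vocab_size
# Floor division: stays int64 (identical to truncation for non-negative indices).
floor_div = flat_indices // vocab_size

print(true_div.dtype, floor_div.dtype)
```

On PyTorch 1.8 and later, `torch.div(flat_indices, vocab_size, rounding_mode="floor")` is a more explicit spelling of the same floor division.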

This is my dataset directory structure:

(virtex_env) (base) [p virtex]$ ls datasets/
coco  vocab
(virtex_env) (base) [p virtex]$ ls datasets/vocab/
coco_10k.model  coco_10k.vocab
(virtex_env) (base) [p virtex]$ ls datasets/coco/
annotations  serialized_train.lmdb  serialized_train.lmdb-lock  serialized_val.lmdb  serialized_val.lmdb-lock  train2017  val2017

What am I missing?

Thanks

kdexd commented 3 years ago

Hi @yonatanbitton, I found this issue yesterday. It is caused by the PyTorch upgrade from 1.4 to 1.7. I pushed a fix in 648fd67. Feel free to re-open if this does not work, thanks!
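The thread does not show what commit 648fd67 changes. A common shape for this bug in beam search code (an assumption here, not a description of that commit) is computing backpointers with `/`, which yields float indices on PyTorch 1.7+ and then fails in `gather`. Below is a minimal sketch of one beam step that keeps backpointers and token ids as int64; `select_top_k` and `reorder_history` are hypothetical helpers, not VirTex's API.

```python
import torch


def select_top_k(log_probs: torch.Tensor, beam_size: int):
    """One beam-search step (hypothetical helper).

    log_probs: (batch_size, num_beams, vocab_size) scores for every
    (beam, next-token) extension. Returns (scores, backpointers, tokens),
    each of shape (batch_size, beam_size); backpointers and tokens are int64.
    """
    batch_size, num_beams, vocab_size = log_probs.shape
    # Flatten beams and vocabulary so a single topk ranks all extensions.
    flat = log_probs.reshape(batch_size, num_beams * vocab_size)
    top_scores, flat_indices = flat.topk(beam_size, dim=1)  # indices are int64

    # Floor division keeps int64; `flat_indices / vocab_size` would be float
    # on PyTorch >= 1.7 and break the later `gather` call.
    backpointers = flat_indices // vocab_size
    tokens = flat_indices % vocab_size
    return top_scores, backpointers, tokens


def reorder_history(history: torch.Tensor, backpointers: torch.Tensor) -> torch.Tensor:
    """Reorder past predictions (batch, beams, steps) to follow the chosen beams."""
    index = backpointers.unsqueeze(2).expand(-1, -1, history.size(2))
    return history.gather(1, index)  # index is int64, so gather accepts it
```

For example, `select_top_k(torch.randn(2, 3, 10), beam_size=3)` returns int64 backpointers in `[0, 3)` that can be passed straight to `reorder_history`. Flattening beams and vocabulary into one `topk` is just one way to rank extensions; the int64 requirement on the `gather` index applies regardless.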