aimagelab / meshed-memory-transformer

Meshed-Memory Transformer for Image Captioning. CVPR 2020
BSD 3-Clause "New" or "Revised" License

RuntimeError: gather(): Expected dtype int64 for index, in beam_search/beam_search.py, line 26, in fn #82

Closed: linhuixiao closed this issue 1 year ago

linhuixiao commented 2 years ago

Meshed-Memory Transformer Evaluation
Evaluation:   0%|          | 0/500 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "test.py", line 78, in <module>
    scores = predict_captions(model, dict_dataloader_test, text_field)
  File "test.py", line 26, in predict_captions
    out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1)
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/captioning_model.py", line 70, in beam_search
    return bs.apply(visual, out_size, return_probs, **kwargs)
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 71, in apply
    visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs)
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 121, in iter
    self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/containers.py", line 30, in apply_to_states
    self._buffers[name] = fn(self._buffers[name])
  File "/home/lhxiao/pcl_experiment_202203/meshed-memory-transformer/models/beam_search/beam_search.py", line 26, in fn
    s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
RuntimeError: gather(): Expected dtype int64 for index

linhuixiao commented 2 years ago

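I have solved this bug.

In recent PyTorch versions, the / operator on integer tensors performs true division and returns a float tensor, so selected_beam is no longer int64 and torch.gather rejects it as an index. Please fix the code in models/beam_search/beam_search.py, line 118, to use floor division instead: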

        # selected_beam = selected_idx / candidate_logprob.shape[-1]
        selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode="floor")
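To illustrate why floor division is the right operation here, below is a minimal pure-Python sketch, not the repository's code. It assumes that selected_idx holds flat indices into a score table of shape (beam size, vocabulary size) and that candidate_logprob.shape[-1] is the vocabulary size. The sizes and index values are hypothetical, and plain lists stand in for tensors. The integer // operator plays the role of torch.div(..., rounding_mode="floor").

```python
# Hypothetical sizes, chosen only for illustration (not taken from the repo).
beam_size = 3
vocab_size = 5  # assumed to correspond to candidate_logprob.shape[-1]

# Flat indices into a (beam_size x vocab_size) score table, as a top-k over
# the flattened scores might return them. Values are made up.
selected_idx = [14, 7, 2]

# True division yields floats, which cannot serve as integer indices.
true_div = [i / vocab_size for i in selected_idx]

# Floor division recovers the beam each candidate came from;
# the remainder recovers the word within that beam.
selected_beam = [i // vocab_size for i in selected_idx]
selected_word = [i % vocab_size for i in selected_idx]

print(true_div)       # float quotients
print(selected_beam)  # integer beam indices
print(selected_word)  # integer word indices
```

With tensors the same distinction applies: the quotient must stay an integer type for it to be accepted as the index argument of torch.gather.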