JudeLee19 / HMNet-End-to-End-Abstractive-Summarization-for-Meetings

"End-to-End Abstractive Summarization for Meetings" paper - Unofficial PyTorch Implementation
52 stars 13 forks source link

datatype in eval step #10

Open danielkope opened 3 years ago

danielkope commented 3 years ago

I ran into an issue following the readme steps:

[0:00:07.825803][Epoch: 0][Iter: 58][Loss: 8.262327][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:08<00:00, 7.07it/s] [0:00:15.255366][Epoch: 1][Iter: 116][Loss: 8.172640][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.97it/s] [0:00:23.040041][Epoch: 2][Iter: 174][Loss: 7.575418][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.60it/s] [0:00:30.665424][Epoch: 3][Iter: 232][Loss: 6.490891][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.80it/s] [0:00:38.155240][Epoch: 4][Iter: 290][Loss: 6.388737][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.88it/s] [0:00:45.909799][Epoch: 5][Iter: 348][Loss: 6.024454][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.68it/s] [0:00:53.494863][Epoch: 6][Iter: 406][Loss: 6.057027][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.81it/s] [0:01:00.988230][Epoch: 7][Iter: 464][Loss: 5.623709][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.92it/s] [0:01:08.623792][Epoch: 8][Iter: 522][Loss: 6.075109][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.71it/s] [0:01:16.303407][Epoch: 9][Iter: 580][Loss: 5.305166][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.72it/s] [0:01:24.003714][Epoch: 10][Iter: 638][Loss: 4.078269][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.71it/s] [0:01:31.638251][Epoch: 11][Iter: 696][Loss: 5.103086][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.74it/s] [0:01:39.273713][Epoch: 12][Iter: 754][Loss: 4.812594][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.78it/s] [0:01:46.958639][Epoch: 13][Iter: 812][Loss: 4.182279][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.71it/s] [0:01:54.650250][Epoch: 14][Iter: 870][Loss: 5.054197][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.69it/s] [0:02:02.245296][Epoch: 15][Iter: 928][Loss: 5.287556][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.82it/s] [0:02:09.885518][Epoch: 16][Iter: 986][Loss: 4.242270][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.73it/s] [0:02:17.617504][Epoch: 17][Iter: 1044][Loss: 3.547390][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.65it/s] [0:02:25.073161][Epoch: 18][Iter: 1102][Loss: 4.753693][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.97it/s] [0:02:32.696943][Epoch: 19][Iter: 1160][Loss: 4.335735][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.75it/s] [0:02:40.117469][Epoch: 20][Iter: 1218][Loss: 3.922061][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 8.00it/s] ======= Evaluation Start Epoch: 20 ================== 0%| | 0/400 [00:00<?, ?it/s] 0%| | 0/20 [00:02<?, ?it/s] Traceback (most recent call last): File "main.py", line 110, in train_model(args) File "main.py", line 55, in train_model summarization.train() File "/mnt/data/HMNET/HMNet-End-to-End-Abstractive-Summarization-for-Meetings/train.py", line 180, in train eval_path=self.previous_model_path) File "/mnt/data/HMNET/HMNet-End-to-End-Abstractive-Summarization-for-Meetings/predictor.py", line 110, in evaluate role_ids=role_ids, pos_ids=pos_ids) File "/mnt/data/HMNET/HMNet-End-to-End-Abstractive-Summarization-for-Meetings/predictor.py", line 250, in inference [alive_seq.index_select(0, select_indices), RuntimeError: expected scalar type Long but found Float

spencernelsonucla commented 3 years ago

@danielkope try this:

in predictor.py, change:

            # Append last prediction.
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),
                 topk_ids.view(-1, 1)], -1)

to:

            alive_seq = torch.cat([alive_seq[select_indices.long()],
                                  topk_ids.view(-1, 1)], 
                                  -1)

you may need to do this in other spots, such as lines 294-297. From:

            word_level_memory_beam = word_level_memory_beam.index_select(0, select_indices)
            turn_level_memory_beam = turn_level_memory_beam.index_select(0, select_indices)
            decoder_state.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

to:

            word_level_memory_beam = word_level_memory_beam[select_indices.long()]
            turn_level_memory_beam = turn_level_memory_beam[select_indices.long()]
            decoder_state.map_batch_fn(
                lambda state, dim: state[select_indices.long()])