Open danielkope opened 3 years ago
@danielkope try this:
in predictor.py, change:
# Append last prediction.
alive_seq = torch.cat(
[alive_seq.index_select(0, select_indices),
topk_ids.view(-1, 1)], -1)
to:
alive_seq = torch.cat([alive_seq[select_indices.long()],
topk_ids.view(-1, 1)],
-1)
you may need to do this in other spots, such as lines 294-297. From:
word_level_memory_beam = word_level_memory_beam.index_select(0, select_indices)
turn_level_memory_beam = turn_level_memory_beam.index_select(0, select_indices)
decoder_state.map_batch_fn(
lambda state, dim: state.index_select(dim, select_indices))
to:
word_level_memory_beam = word_level_memory_beam[select_indices.long()]
turn_level_memory_beam = turn_level_memory_beam[select_indices.long()]
decoder_state.map_batch_fn(
lambda state, dim: state[select_indices.long()])
I ran into an issue following the readme steps:
[0:00:07.825803][Epoch: 0][Iter: 58][Loss: 8.262327][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:08<00:00, 7.07it/s] [0:00:15.255366][Epoch: 1][Iter: 116][Loss: 8.172640][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.97it/s] [0:00:23.040041][Epoch: 2][Iter: 174][Loss: 7.575418][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.60it/s] [0:00:30.665424][Epoch: 3][Iter: 232][Loss: 6.490891][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.80it/s] [0:00:38.155240][Epoch: 4][Iter: 290][Loss: 6.388737][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.88it/s] [0:00:45.909799][Epoch: 5][Iter: 348][Loss: 6.024454][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.68it/s] [0:00:53.494863][Epoch: 6][Iter: 406][Loss: 6.057027][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.81it/s] [0:01:00.988230][Epoch: 7][Iter: 464][Loss: 5.623709][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.92it/s] [0:01:08.623792][Epoch: 8][Iter: 522][Loss: 6.075109][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.71it/s] [0:01:16.303407][Epoch: 9][Iter: 580][Loss: 5.305166][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.72it/s] [0:01:24.003714][Epoch: 10][Iter: 638][Loss: 4.078269][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.71it/s] [0:01:31.638251][Epoch: 11][Iter: 696][Loss: 5.103086][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.74it/s] [0:01:39.273713][Epoch: 12][Iter: 754][Loss: 4.812594][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.78it/s] [0:01:46.958639][Epoch: 13][Iter: 812][Loss: 4.182279][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.71it/s] [0:01:54.650250][Epoch: 14][Iter: 870][Loss: 5.054197][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.69it/s] [0:02:02.245296][Epoch: 15][Iter: 928][Loss: 5.287556][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.82it/s] [0:02:09.885518][Epoch: 16][Iter: 986][Loss: 4.242270][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.73it/s] [0:02:17.617504][Epoch: 17][Iter: 1044][Loss: 3.547390][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.65it/s] [0:02:25.073161][Epoch: 18][Iter: 1102][Loss: 4.753693][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.97it/s] [0:02:32.696943][Epoch: 19][Iter: 1160][Loss: 4.335735][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 7.75it/s] [0:02:40.117469][Epoch: 20][Iter: 1218][Loss: 3.922061][lr: 0.000500]: 100%|██████████████████████████████████████████████████████████████████████████| 58/58 [00:07<00:00, 8.00it/s] ======= Evaluation Start Epoch: 20 ================== 0%| | 0/400 [00:00<?, ?it/s] 0%| | 0/20 [00:02<?, ?it/s] Traceback (most recent call last): File "main.py", line 110, in
train_model(args)
File "main.py", line 55, in train_model
summarization.train()
File "/mnt/data/HMNET/HMNet-End-to-End-Abstractive-Summarization-for-Meetings/train.py", line 180, in train
eval_path=self.previous_model_path)
File "/mnt/data/HMNET/HMNet-End-to-End-Abstractive-Summarization-for-Meetings/predictor.py", line 110, in evaluate
role_ids=role_ids, pos_ids=pos_ids)
File "/mnt/data/HMNET/HMNet-End-to-End-Abstractive-Summarization-for-Meetings/predictor.py", line 250, in inference
[alive_seq.index_select(0, select_indices),
RuntimeError: expected scalar type Long but found Float