Open · NonvolatileMemory opened this issue 3 years ago
Hi, m3yrin,
Thanks for your code~! I ported it to the fairseq framework, but I found that it runs slowly.
The main reason is that during the dynamic programming, calling `torch.gather` on `log_probs` at every step is time-consuming. I think a simple but effective way to fix it is to gather the log-probabilities of all target tokens once, up front, like this:
```python
# concat_targets: [null, ref1, ref2, ref3]
concat_targets = targets.new(targets.size(0), 1).fill_(blank_index)
concat_targets = torch.cat([concat_targets, targets], dim=1)
repeat_targets = concat_targets.repeat(1, target_sequence_length).view(
    batch_size, target_sequence_length, target_sequence_length + 1)

# prev lprobs: bs x seq-len x vocab size
# now  lprobs: bs x seq-len x target seq len
lprobs = lprobs.gather(dim=-1, index=repeat_targets)  # bs x seq x seq

# new targets: 1, 2, 3, 4, 5, 6
targets = torch.tensor(
    np.repeat(np.arange(target_sequence_length).reshape(1, -1, 1) + 1,
              batch_size, axis=0)).to(device)

# (batch_size, target_sequence_length + 1, lprobs_sequence_length + 1)
batch_A = torch.zeros(targets.size(0), targets.size(1) + 1, lprobs.size(1) + 1).to(device)
batch_blank_index = torch.full((lprobs.size(0), 1), 0, dtype=torch.long).to(device)
```
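For anyone reading along, here is a minimal self-contained sketch of the same "gather once up front" idea, with toy sizes and hypothetical variable names (`src_len`, `tgt_len`, `vocab` are assumptions, not names from the original code). After the single gather, the DP recursion only needs to index into a small `(bs, src_len, tgt_len + 1)` tensor instead of gathering from the full vocabulary at every step:

```python
import torch

# Toy sizes (assumed for illustration).
batch_size, src_len, tgt_len, vocab = 2, 5, 4, 10
blank_index = 0  # assume index 0 is the null/blank symbol

lprobs = torch.randn(batch_size, src_len, vocab).log_softmax(dim=-1)
targets = torch.randint(1, vocab, (batch_size, tgt_len))

# Prepend the blank symbol: [null, ref1, ref2, ...]
concat_targets = torch.cat(
    [targets.new_full((batch_size, 1), blank_index), targets],
    dim=1)  # (bs, tgt_len + 1)

# Gather the log-prob of every target token at every source position ONCE.
index = concat_targets.unsqueeze(1).expand(batch_size, src_len, tgt_len + 1)
tgt_lprobs = lprobs.gather(dim=-1, index=index)  # (bs, src_len, tgt_len + 1)

# Sanity checks: shapes and values match a direct lookup.
assert tgt_lprobs.shape == (batch_size, src_len, tgt_len + 1)
assert torch.equal(tgt_lprobs[0, 0, 1], lprobs[0, 0, targets[0, 0]])
```

Using `expand` instead of `repeat` for the index avoids materialising the repeated target tensor; `gather` accepts an expanded (non-contiguous) index directly.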
Hi, NonvolatileMemory,
Thanks for your comment~! I have recently been trying to port the code to the fairseq framework myself, but I ran into a lot of problems! Would it be convenient for you to share your fairseq version of the code?