m3yrin / aligned-cross-entropy

Test implementation of "Aligned Cross Entropy for Non-Autoregressive Machine Translation" https://arxiv.org/abs/2004.01655
MIT License

The training speed is low and we can fix it #1

Open · NonvolatileMemory opened this issue 3 years ago

NonvolatileMemory commented 3 years ago

Hi, m3yrin,

Thanks for your code~! I adapted it to the fairseq framework, but I found that training is slow.

The main reason is that, during the dynamic programming, gathering the log-probs with torch.gather over the full vocabulary is time-consuming. I think a simple but effective way to fix it is the following:

        # Pre-gather the target-token log-probs once, so the dynamic
        # program no longer needs to gather over the full vocabulary.
        # (Assumes `import torch` and `import numpy as np` in the module.)

        # concat_targets per example: [blank, ref1, ref2, ref3, ...]
        concat_targets = targets.new(targets.size(0), 1).fill_(blank_index)
        concat_targets = torch.cat([concat_targets, targets], dim=1)
        repeat_targets = concat_targets.repeat(1, target_sequence_length).view(
            batch_size, target_sequence_length, target_sequence_length + 1)

        # before: lprobs is (batch_size, seq_len, vocab_size)
        # after:  lprobs is (batch_size, seq_len, target_sequence_length + 1)
        lprobs = lprobs.gather(dim=-1, index=repeat_targets)

        # remap the targets to their column positions in the gathered tensor: 1, 2, 3, ...
        targets = torch.tensor(
            np.repeat(np.arange(target_sequence_length).reshape(1, -1, 1) + 1,
                      batch_size, axis=0)).to(device)

        # DP table: (batch_size, target_sequence_length + 1, lprobs_sequence_length + 1)
        batch_A = torch.zeros(targets.size(0), targets.size(1) + 1, lprobs.size(1) + 1).to(device)
        # column 0 of the gathered lprobs is the blank entry
        batch_blank_index = torch.full((lprobs.size(0), 1), 0, dtype=torch.long).to(device)
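
For illustration, here is roughly how the pre-gathered tensor could then be used inside the AXE dynamic program, so the inner loop only indexes the small (target_sequence_length + 1)-wide tensor instead of gathering from the full vocabulary. This is only a minimal, non-vectorized sketch assuming the align / skip-prediction / skip-target recursion from the paper; axe_dp_sketch, blank_pos, and the loop structure are illustrative names, not taken from either repository.

    import torch

    def axe_dp_sketch(lprobs, target_len, blank_pos=0):
        # Illustrative, non-vectorized AXE dynamic program over the
        # pre-gathered log-probs. `lprobs` is assumed to be the
        # (batch, pred_len, target_len + 1) tensor produced above, where
        # column 0 holds log P_j(blank) and column i holds log P_j(Y_i).
        # This helper is a hypothetical sketch, not code from either repo.
        bs, pred_len, _ = lprobs.size()
        inf = float("inf")
        # A[:, i, j] = minimal negative log-likelihood of aligning the
        # first i target tokens with the first j predictions.
        A = lprobs.new_full((bs, target_len + 1, pred_len + 1), inf)
        A[:, 0, 0] = 0.0
        for j in range(1, pred_len + 1):
            # no targets consumed yet: every prediction aligned to blank
            A[:, 0, j] = A[:, 0, j - 1] - lprobs[:, j - 1, blank_pos]
        for i in range(1, target_len + 1):
            for j in range(1, pred_len + 1):
                tok = -lprobs[:, j - 1, i]            # -log P_j(Y_i), no vocab gather
                blank = -lprobs[:, j - 1, blank_pos]  # -log P_j(blank)
                align = A[:, i - 1, j - 1] + tok      # align Y_i with prediction j
                skip_pred = A[:, i, j - 1] + blank    # prediction j emits a blank
                skip_tgt = A[:, i - 1, j] + tok       # Y_i consumed without a new prediction
                A[:, i, j] = torch.min(torch.min(align, skip_pred), skip_tgt)
        # per-example loss of the best alignment
        return A[:, target_len, pred_len]

The speedup comes from the per-cell lookups now touching at most target_sequence_length + 1 entries per position rather than vocabulary-size entries; a real implementation would of course also vectorize these loops instead of iterating in Python.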
Ariel19977 commented 3 years ago

Hi NonvolatileMemory, thanks for your comment~! I also recently wanted to port the code to the fairseq framework, but I ran into a lot of problems. Would it be convenient for you to share your fairseq version of the code?