chenyangh / DSLP

Deeply Supervised, Layer-wise Prediction-aware (DSLP) Transformer for Non-autoregressive Neural Machine Translation
MIT License

Does your implementation for CTC + GLAT work? #5

Closed SirRob1997 closed 2 years ago

SirRob1997 commented 2 years ago

I've looked through the code for nat_ctc_glat.py and was wondering how the alignment works for the glancing sampling. For plain CTC without GLAT this is handled by F.ctc_loss, but it seems less straightforward for the GLAT part. I tried to code it up following parts of the implementation here as well as the GLAT repository.

For me, it fails at the check pred_tokens == tgt_tokens in the GLAT part, which makes sense: pred_tokens has the length of the CTC-upsampled source, while tgt_tokens is most likely shorter.

I'm not sure whether it fails with your exact code as well, but it would make sense to me. What did you change to make this work?

chenyangh commented 2 years ago

Hi,

Actually, this ctc+glat model was not finished, and I should probably delete it.

According to Gu and Kong (2021), we can treat the most probable CTC alignment as the target. I actually have this part implemented in the CTC with DSLP & Mixed Training model. You can find it here: https://github.com/chenyangh/DSLP/blob/a803b46ffc7aaf4b4117a9bade83b106237446e7/fairseq/models/nat/nat_ctc_sd_ss.py#L384

You can then use that aligned target in place of tgt_tokens in the normal GLAT training.
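
Roughly, the idea is something like the sketch below (not our exact code; get_ctc_target_tokens, glat_ratio, and the other names are just placeholders for whatever your model uses):

import torch

def glancing_inputs(self, prev_output_tokens, logits, tgt_tokens, glat_ratio=0.5):
    # hypothetical helper that returns target token IDs aligned to the
    # (upsampled) decoder length via the best CTC alignment
    oracle = self.get_ctc_target_tokens(tgt_tokens, prev_output_tokens, logits)

    pred_tokens = logits.argmax(-1)                       # (B, T_upsampled)
    nonpad = prev_output_tokens.ne(self.pad)

    # comparing against the aligned oracle avoids the length mismatch with tgt_tokens
    n_wrong = ((pred_tokens != oracle) & nonpad).sum(1)   # per-sentence error count

    # reveal ("glance at") more positions when the current prediction is worse
    keep_prob = (n_wrong.float() * glat_ratio / nonpad.sum(1).clamp(min=1)).unsqueeze(1)
    noise = torch.rand(pred_tokens.shape, device=pred_tokens.device)
    glance_mask = (noise < keep_prob) & nonpad

    # feed oracle tokens at the glanced positions, keep the original input elsewhere
    return torch.where(glance_mask, oracle, prev_output_tokens)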

SirRob1997 commented 2 years ago

Ahhh, I see. To verify: so oracle will be the tensor used as tgt_tokens? Or is it sufficient to use best_aligns[0] or best_aligns_pad[0]?

This would be my current version:

# assumes `import torch` and the best_alignment helper from the
# imputer-pytorch implementation used in this repo (see the link further down)
def get_ctc_target_tokens(self, tgt_tokens, prev_output_tokens, logits):
    nonpad_positions = tgt_tokens.ne(self.pad)
    seq_lens = nonpad_positions.sum(1)                  # target lengths
    output_masks = prev_output_tokens.ne(self.pad)
    output_length = output_masks.sum(dim=-1)            # upsampled decoder lengths
    logits_T = logits.transpose(0, 1).float()           # time-major (T, B, V) for best_alignment
    best_aligns = best_alignment(logits_T, tgt_tokens, output_length, seq_lens, self.pad, zero_infinity=True)
    # pad each per-sentence list of alignment states up to the full decoder length
    best_aligns_pad = torch.tensor(
        [a + [self.pad] * (logits_T.size(0) - len(a)) for a in best_aligns],
        device=logits_T.device, dtype=tgt_tokens.dtype)
    # map alignment states to target positions (see the // 2 discussion below)
    oracle_pos = (best_aligns_pad // 2).clip(max=tgt_tokens.shape[1] - 1)
    oracle = tgt_tokens.gather(-1, oracle_pos)          # token IDs aligned to the decoder length
    return oracle
chenyangh commented 2 years ago

Yes, oracle gives the token IDs of the best alignment; best_aligns itself only contains the positions.

SirRob1997 commented 2 years ago

Cool, thanks! Why is // 2 needed in the above code?

chenyangh commented 2 years ago

The algorithm follows the Imputer paper; the implementation is from https://github.com/rosinality/imputer-pytorch
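
Roughly, as I understand that implementation, best_alignment returns indices into the CTC-extended target (blank, y1, blank, y2, ..., blank), where even states are blanks, so state // 2 recovers the corresponding target position. A toy sketch of the indexing (not the library's exact output format):

# toy sketch of the CTC alignment-state indexing (not the exact imputer-pytorch API)
tgt = ["A", "B", "C"]                                   # target tokens
extended = ["<b>", "A", "<b>", "B", "<b>", "C", "<b>"]  # blanks interleaved; state i = extended[i]

best_align = [0, 1, 1, 2, 3, 4, 5, 5]                   # one possible alignment over 8 decoder steps

for state in best_align:
    is_blank = state % 2 == 0                           # even states are blanks
    pos = state // 2                                    # target position, as used for oracle_pos
    print(state, pos, "<b>" if is_blank else tgt[pos])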

SirRob1997 commented 2 years ago

Thanks so much for the quick responses, this was really helpful! One last question: What is the point of the --no-empty flag or rather the masks that are used here:

https://github.com/chenyangh/DSLP/blob/a803b46ffc7aaf4b4117a9bade83b106237446e7/fairseq/models/nat/nat_ctc_sd_ss.py#L389

Are they needed?

chenyangh commented 2 years ago

It is a good question, Robin.

The // 2 operation essentially puts the empty (blank) tokens into the intermediate argmax sequence. Since we know that the empty tokens are not informative, we tried not performing this operation (although it is slightly different from the math).

In our experiments it seemed to have some small benefits, but they were not very consistent. Therefore, we chose not to include this trick, for simplicity.
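
If it helps, here is roughly what I understand those masks to amount to; the function and argument names below are placeholders for illustration, not the repository's exact code:

def apply_empty_mask(oracle, best_aligns_pad, blank_index, no_empty=False):
    # Even alignment states correspond to blanks in the CTC-extended target.
    # By default the blank token is put back at those positions, which stays
    # closer to the CTC math; with --no-empty that step is skipped and the
    # oracle (blanks replaced by the // 2-mapped neighbouring targets) is used as-is.
    if no_empty:
        return oracle
    return oracle.masked_fill(best_aligns_pad % 2 == 0, blank_index)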