Hi,
Actually, this ctc+glat model was not finished, and I should probably delete it.
According to Gu and Kong (2021), we can treat the most probable CTC alignment as the target. I actually have this part implemented in the CTC with DSLP & Mixed Training model; you can find it here: https://github.com/chenyangh/DSLP/blob/a803b46ffc7aaf4b4117a9bade83b106237446e7/fairseq/models/nat/nat_ctc_sd_ss.py#L384
Then you can replace that in the normal GLAT training.
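For concreteness, here is a minimal sketch of what "replace that in the normal GLAT training" could look like: the glancing comparison and token mixing are done against a CTC-aligned oracle of decoder length instead of against `tgt_tokens` directly. This is not the repository's exact code; `get_ctc_target_tokens` is the helper discussed further down in this thread, and `ratio`, `model.pad`, and the token-level mixing are simplifying assumptions:

```python
import torch

def glat_glancing_with_ctc_oracle(model, logits, prev_output_tokens, tgt_tokens, ratio=0.5):
    """Sketch: GLAT glancing where the comparison target is a CTC-aligned oracle
    (decoder length) instead of the raw, shorter tgt_tokens."""
    with torch.no_grad():
        # Oracle reference with one target token per decoder position.
        oracle = model.get_ctc_target_tokens(tgt_tokens, prev_output_tokens, logits)
        pred_tokens = logits.argmax(-1)                      # [B, T]
        nonpad = prev_output_tokens.ne(model.pad)

        # How many positions to reveal: proportional to the number of wrong predictions.
        num_wrong = ((pred_tokens != oracle) & nonpad).sum(1)
        num_reveal = (num_wrong.float() * ratio).long()

        # Randomly choose num_reveal non-pad positions per sentence.
        scores = torch.rand_like(logits[..., 0]).masked_fill(~nonpad, 2.0)
        reveal_mask = scores.argsort(1).argsort(1) < num_reveal.unsqueeze(1)

        # Feed oracle tokens at the revealed positions, keep the decoder input elsewhere.
        glanced_input = prev_output_tokens.masked_scatter(reveal_mask, oracle[reveal_mask])
    return glanced_input, reveal_mask
```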
Ahhh, I see. To verify: so `oracle` will be the tensor used as `tgt_tokens`? Or is it sufficient to use `best_aligns[0]` or `best_aligns_pad[0]`?
This would be my current version:
```python
def get_ctc_target_tokens(self, tgt_tokens, prev_output_tokens, logits):
    nonpad_positions = tgt_tokens.ne(self.pad)
    seq_lens = nonpad_positions.sum(1)                 # reference lengths
    output_masks = prev_output_tokens.ne(self.pad)
    output_length = output_masks.sum(dim=-1)           # (upsampled) decoder lengths
    logits_T = logits.transpose(0, 1).float()          # best_alignment expects time-major [T, B, V]
    best_aligns = best_alignment(logits_T, tgt_tokens, output_length, seq_lens, self.pad, zero_infinity=True)
    # pad each per-sentence alignment-state sequence to the decoder length
    best_aligns_pad = torch.tensor([a + [self.pad] * (logits_T.size(0) - len(a)) for a in best_aligns], device=logits_T.device, dtype=tgt_tokens.dtype)
    # states live in a 2*L+1 lattice (even = blank, odd = target token), so // 2 gives a target position
    oracle_pos = (best_aligns_pad // 2).clip(max=tgt_tokens.shape[1] - 1)
    oracle = tgt_tokens.gather(-1, oracle_pos)
    return oracle
```
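As a quick sanity check on shapes (hypothetical tensors; `model` is assumed to expose the method above, with `pad = 1`):

```python
import torch

B, T, L, V = 2, 10, 4, 8                      # batch, decoder (upsampled) length, target length, vocab
logits = torch.randn(B, T, V)                 # decoder outputs
prev_output_tokens = torch.full((B, T), 5)    # upsampled decoder input, no padding in this toy case
tgt_tokens = torch.randint(3, V, (B, L))      # references, shorter than the decoder length

oracle = model.get_ctc_target_tokens(tgt_tokens, prev_output_tokens, logits)
assert oracle.shape == (B, T)                 # one reference token per decoder position
```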
Yes, `oracle` is the token IDs of the best alignment; `best_aligns` itself is only the positions.
Cool, thanks! Why is the `// 2` needed in the code above?
The algorithm comes from the Imputer paper; we use this implementation: https://github.com/rosinality/imputer-pytorch
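A toy example might make the `// 2` clearer (hypothetical numbers): `best_alignment` returns state indices in the standard CTC lattice of size `2*L + 1`, where blanks sit at even indices and target token `i` sits at odd index `2*i + 1`, so integer division by 2 recovers a target position:

```python
import torch

# Reference [7, 9, 4] expands to 7 lattice states: [blank, 7, blank, 9, blank, 4, blank].
best_align = torch.tensor([0, 1, 2, 3, 3, 5])   # one valid best-alignment state path
tgt = torch.tensor([7, 9, 4])

oracle_pos = (best_align // 2).clip(max=tgt.size(0) - 1)
print(oracle_pos.tolist())       # [0, 0, 1, 1, 1, 2]  (blank states map to the following token)
print(tgt[oracle_pos].tolist())  # [7, 7, 9, 9, 9, 4]  -> the oracle token IDs
```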
Thanks so much for the quick responses, this was really helpful! One last question: what is the point of the `--no-empty` flag, or rather the masks that are used here? Are they needed?
It is a good question, Robin. The `// 2` operation essentially puts the empty (blank) tokens into the intermediate argmax sequence. Since we know that the empty tokens are not informative, we tried not performing this operation (although that is slightly different from the math). In our experiments it may have some small benefits, but not very consistent ones, so we chose not to include this trick, for simplicity.
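For reference, a minimal sketch of what keeping the blanks explicit could look like, reusing `oracle` and `best_aligns_pad` from the snippet above; `self.blank_idx` is an assumed attribute name:

```python
# Even alignment states are blank states: keep them as blanks in the oracle instead of
# copying the neighbouring target token (one of the two variants the --no-empty discussion is about).
oracle_with_blanks = oracle.masked_fill(best_aligns_pad % 2 == 0, self.blank_idx)  # blank_idx is assumed
```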
I've looked through the code for `nat_ctc_glat.py` and was wondering how the alignment works for the glancing sampling. For normal CTC without GLAT this is handled by `F.ctc_loss`, but it seems it's not so straightforward for the GLAT part. I tried to code it up following some of the implementation here as well as the one in the GLAT repository. For me, it fails at this check: `pred_tokens == tgt_tokens` in the GLAT part. That makes sense, as `pred_tokens` will have the length of the upsampled source from CTC, while `tgt_tokens` are most likely shorter. I'm not sure whether it fails with your exact code as well, but it would make sense to me. What did you change to make this work?
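To make the length mismatch concrete, a hedged sketch of the shapes involved (names follow the snippets above; the `2 * src_len` upsampling factor is an assumption):

```python
pred_tokens = logits.argmax(-1)   # [B, T], with T = upsampled source length (e.g. 2 * src_len)
# tgt_tokens is [B, L] and L != T in general, so pred_tokens == tgt_tokens cannot be
# evaluated elementwise. Comparing against the length-T CTC oracle instead makes the
# glancing check well-defined:
same = pred_tokens == model.get_ctc_target_tokens(tgt_tokens, prev_output_tokens, logits)  # [B, T] bool
```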