TeaPoly / CTC-OptimizedLoss

Computes the MWER (minimum WER) loss with CTC beam search, plus knowledge distillation for CTC loss.

How to use the o1 loss? #4

Open teinhonglo opened 7 months ago

teinhonglo commented 7 months ago

Thanks for sharing the code. Could you provide an example of the o1 loss? I've combined it with the CTC loss as shown in the following code, but performance does not seem to improve.

```python
import torch
from torch import nn

# ctc_loss expects time-major (T, B, V) log-probs, hence the transpose
log_probs = nn.functional.log_softmax(
    logits, dim=-1, dtype=torch.float32
).transpose(0, 1)

with torch.backends.cudnn.flags(enabled=False):
    loss = nn.functional.ctc_loss(
        log_probs,
        flattened_targets,
        input_lengths,
        target_lengths,
        blank=self.config.pad_token_id,
        reduction=self.config.ctc_loss_reduction,     # default: sum, use_focal_loss=none
        zero_infinity=self.config.ctc_zero_infinity,  # default: false
    )

# o1_loss takes batch-major (B, T, V) log-probs, so transpose back
o1_loss = self.o1_loss(
    log_probs.transpose(0, 1),
    input_lengths,
    labels,
    target_lengths,
)

if self.use_o1_loss:
    o1_loss /= batch_size
    loss = 0.01 * loss + 1.0 * o1_loss
```
TeaPoly commented 7 months ago

> Could you provide an example of the o1 loss? I've combined it with the CTC loss as shown above, but performance does not seem to improve.

Beam search plays a far larger role in RNN-T than in CTC decoding, and I agree with your conclusion on that point. For now, this loss function is only an experimental exercise.
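
For context, the MWER objective this repository targets is the expected edit distance over an N-best list from beam search, so the quality of the hypotheses the search surfaces directly limits what the loss can learn from. A minimal sketch of that objective (the function name, tensor shapes, and mean-error baseline are illustrative assumptions, not this repository's API):

```python
import torch

def mwer_loss(hyp_log_probs: torch.Tensor, hyp_errors: torch.Tensor) -> torch.Tensor:
    """Expected word-error objective over an N-best list.

    hyp_log_probs: (B, N) total log-probability of each beam hypothesis.
    hyp_errors:    (B, N) edit distance of each hypothesis to the reference.
    """
    # Renormalize probability mass over the N-best list per utterance.
    weights = torch.softmax(hyp_log_probs, dim=-1)
    # Subtracting the mean error is the usual variance-reduction baseline.
    errors = hyp_errors.float()
    relative_errors = errors - errors.mean(dim=-1, keepdim=True)
    return (weights * relative_errors).sum(dim=-1).mean()
```

Since only hypotheses inside the beam receive gradient, a weak CTC beam search gives the objective little to discriminate between, which matches the observation above.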

teinhonglo commented 7 months ago

Thanks for your response.

I have another question regarding CTC optimization. In your experience, which modification in this repository has been most beneficial for reducing the CTC loss?

TeaPoly commented 7 months ago

> Thanks for your response.
>
> I have another question regarding CTC optimization. In your experience, which modification in this repository has been most beneficial for reducing the CTC loss?

Inter-CTC is very useful for deep NN models, and CTC-CRF is useful for small datasets: https://github.com/thu-spmi/CAT
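
As a reference point, Inter-CTC (intermediate CTC regularization) adds an auxiliary CTC head on a middle encoder layer and interpolates its loss with the final CTC loss. A minimal sketch, assuming a batch-first encoder; the layer index, projection heads, and 0.3 weight are illustrative choices, not this repository's code:

```python
import torch
from torch import nn

class InterCTCEncoder(nn.Module):
    """Encoder stack with an auxiliary CTC branch on an intermediate layer."""

    def __init__(self, layers: nn.ModuleList, d_model: int, vocab_size: int,
                 inter_layer: int = 6, inter_weight: float = 0.3, blank: int = 0):
        super().__init__()
        self.layers = layers
        self.inter_layer = inter_layer
        self.inter_weight = inter_weight
        self.blank = blank
        self.inter_proj = nn.Linear(d_model, vocab_size)  # auxiliary CTC head
        self.final_proj = nn.Linear(d_model, vocab_size)  # main CTC head

    def forward(self, x, x_lens, targets, target_lens):
        # x: (B, T, D) encoder input; ctc_loss takes time-major (T, B, V) log-probs.
        inter_loss = x.new_zeros(())
        for i, layer in enumerate(self.layers, start=1):
            x = layer(x)
            if i == self.inter_layer:
                log_probs = self.inter_proj(x).log_softmax(-1).transpose(0, 1)
                inter_loss = nn.functional.ctc_loss(
                    log_probs, targets, x_lens, target_lens,
                    blank=self.blank, zero_infinity=True)
        log_probs = self.final_proj(x).log_softmax(-1).transpose(0, 1)
        final_loss = nn.functional.ctc_loss(
            log_probs, targets, x_lens, target_lens,
            blank=self.blank, zero_infinity=True)
        # Interpolate the final and intermediate CTC losses.
        return (1.0 - self.inter_weight) * final_loss + self.inter_weight * inter_loss
```

The intermediate supervision regularizes the lower layers of the stack, which is where the reported gains for deep encoders come from.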