k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

A warning is always printed when there is an extreme reduction in the encoder output frame rate. #1271

Closed TeaPoly closed 4 months ago

TeaPoly commented 4 months ago

I am using an extreme encoder output frame rate because of the paper Extreme Encoder Output Frame Rate Reduction: Improving Computational Latencies of Large End-to-End Models, which shows that output frame rates of 320~5120 ms with HAT loss achieve comparable results with Alignment-Length Synchronous Decoding.

When I use an encoder output frame rate of 2560 ms in total with prune_range set to 5, I get warnings like this:

Warning: get_rnnt_prune_ranges - got s_range=5 for boundaries S=tensor([17, 18, 14, 13, 16, 20, 16, 18, 20, 18, 11, 19, 18, 15, 16, 15, 19, 16,
        15, 15, 15, 18, 16, 11, 15, 15, 17, 19, 16, 17, 17, 16, 15, 17, 16, 15,
        17, 11, 15, 16, 17], device='cuda:3'), T=tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:3'). Adjusting to 11

When I set prune_range to 20, I get a different warning:

Warning: get_rnnt_prune_ranges - got s_range=20 for boundaries S=19. Adjusting to 20

What should I do to handle this problem?

This is how I use the pruned RNN-T loss:

        device = x.device
        y = k2.RaggedTensor(padding_tesor_to_list(y, y_lens)).to(device)

        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0

        encoder_out = x
        assert torch.all(x_lens > 0)

        # Now for the decoder, i.e., the prediction network
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        blank_id = decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=self.lm_scale,
                am_only_scale=self.am_scale,
                boundary=boundary,
                reduction=self.reduction,
                return_grad=True,
                delay_penalty=delay_penalty,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=self.prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=joiner.encoder_proj(encoder_out),
            lm=joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction=self.reduction,
                delay_penalty=delay_penalty,
            )

        return (simple_loss, pruned_loss)
pkufool commented 4 months ago

If there are not many of these warnings, you can just ignore them. They only tell you that prune_range is too small to cover all the labels (transcript tokens) within t steps, so prune_range is increased automatically. If the logs contain a lot of them, you can try using a larger prune_range.
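
For picking a larger prune_range up front, here is a rough sketch of the arithmetic. It assumes that each frame has to cover about ceil(S / T) labels plus one extra position. That rule is only inferred from the numbers in the first warning above, not taken from the k2 source, and `min_prune_range` is a hypothetical helper, not a k2 function.

```python
def min_prune_range(s_lens, t_lens):
    """Smallest s_range that lets every utterance consume all S labels in T frames.

    Assumption (inferred from the warning in this issue, not from the k2
    source): each frame must cover ceil(S / T) labels, plus one extra position.
    """
    needed = 0
    for s, t in zip(s_lens, t_lens):
        per_frame = -(-s // t)  # integer ceil(s / t)
        needed = max(needed, per_frame + 1)
    return needed


# Label lengths (S) and frame counts (T) copied from the first warning.
S = [17, 18, 14, 13, 16, 20, 16, 18, 20, 18, 11, 19, 18, 15, 16, 15, 19, 16,
     15, 15, 15, 18, 16, 11, 15, 15, 17, 19, 16, 17, 17, 16, 15, 17, 16, 15,
     17, 11, 15, 16, 17]
T = [3] * 10 + [2] * 31

print(min_prune_range(S, T))  # 11, the same value as "Adjusting to 11"
```

If that assumption holds, a prune_range at or above the printed value should keep the first kind of warning from appearing for a batch like this one. With only 2 or 3 frames per utterance the required range grows quickly, so it may be worth checking this on a few batches before choosing a fixed value.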

TeaPoly commented 4 months ago

> If there are not many of these warnings, you can just ignore them. They only tell you that prune_range is too small to cover all the labels (transcript tokens) within t steps, so prune_range is increased automatically. If the logs contain a lot of them, you can try using a larger prune_range.

Thanks for your reply.