k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Difference Between Forward & Backward Score Exceeds 1.0 #944

Status: Closed (teowenshen closed this issue 2 years ago)

teowenshen commented 2 years ago

Hi,

I am using k2's FSA functionality to intersect CTC graphs (decoding_graph: FsaVec) with the neural network output (nnet_output: torch.Tensor). This is not an ASR project, but it is still time-series classification.

When the neural network assigns an overly high probability to an output symbol that is not in the sample's CTC graph, I understand that intersecting dense_fsa_vec = k2.DenseFsaVec(nnet_output) with decoding_graph will return a low overall probability. However, this also triggers the following warning:

[W] /github/home/miniconda3/conda-bld/k2_1639014631542/work/k2/csrc/intersect_dense.cu:890:k2::Array1 k2::MultiGraphDenseIntersect::GetScoreCutoffs() The difference between forward score and backward score exceeds 1.0, the value is : 2.250000
[W] /github/home/miniconda3/conda-bld/k2_1639014631542/work/k2/csrc/intersect_dense.cu:890:k2::Array1 k2::MultiGraphDenseIntersect::GetScoreCutoffs() The difference between forward score and backward score exceeds 1.0, the value is : 14.000000
[W] /github/home/miniconda3/conda-bld/k2_1639014631542/work/k2/csrc/intersect_dense.cu:890:k2::Array1 k2::MultiGraphDenseIntersect::GetScoreCutoffs() The difference between forward score and backward score exceeds 1.0, the value is : 14.000000
[W] /github/home/miniconda3/conda-bld/k2_1639014631542/work/k2/csrc/intersect_dense.cu:890:k2::Array1 k2::MultiGraphDenseIntersect::GetScoreCutoffs() The difference between forward score and backward score exceeds 1.0, the value is : 2.250000
[W] /github/home/miniconda3/conda-bld/k2_1639014631542/work/k2/csrc/intersect_dense.cu:890:k2::Array1 k2::MultiGraphDenseIntersect::GetScoreCutoffs() The difference between forward score and backward score exceeds 1.0, the value is : 7.000000
[W] /github/home/miniconda3/conda-bld/k2_1639014631542/work/k2/csrc/intersect_dense.cu:890:k2::Array1 k2::MultiGraphDenseIntersect::GetScoreCutoffs() The difference between forward score and backward score exceeds 1.0, the value is : 6.000000
[W] /github/home/miniconda3/conda-bld/k2_1639014631542/work/k2/csrc/intersect_dense.cu:890:k2::Array1 k2::MultiGraphDenseIntersect::GetScoreCutoffs() The difference between forward score and backward score exceeds 1.0, the value is : 3.250000
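
To make the first point concrete, here is a minimal sketch of why an overconfident prediction for a symbol outside the graph drives the total score very low. It is plain Python rather than k2 code; the logits, the three-symbol alphabet, and num_frames are hypothetical values chosen for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum_exp for v in logits]


# Hypothetical frame: symbol 0 gets a huge logit but is NOT in this sample's graph.
frame = log_softmax([50.0, 0.0, 0.0])

# The best log-probability any path inside the graph can pick up on this frame.
best_allowed = max(frame[1], frame[2])

# If every frame looks like this, the best path score grows linearly with length.
num_frames = 1000  # hypothetical sequence length
path_score = num_frames * best_allowed
```

Each such frame contributes about -50 to any path that stays inside the graph, so the total score reaches large magnitudes quickly on long sequences.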

Do you have any idea why this is happening? My code is below, in case I did something wrong.

I am using a simple single-layer TDNN:

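import torch
import torch.nn as nn


class TDNN_Simple(nn.Module):
    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        # No padding: kernel_size=5 with dilation=1 shortens the time axis by 4 frames.
        self.tdnn = nn.Conv1d(
            in_channels=num_features,
            out_channels=num_classes,
            kernel_size=5,
            dilation=1,
        )
        # LayerNorm normalizes over the last (feature) dimension of the input.
        self.layernorm = nn.LayerNorm(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input x: (B, F, C) = (batch, frames, num_features)
        x = self.layernorm(x)
        x = x.permute(0, 2, 1)  # (batch, num_features, frames), as Conv1d expects
        x = self.tdnn(x)
        x = x.permute(0, 2, 1)  # (batch, frames - 4, num_classes)
        x = nn.functional.log_softmax(x, dim=-1)
        return x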

Training code:

    dense_fsa_vec = k2.DenseFsaVec(
        nnet_output,
        supervision_segments, 
        allow_truncate=3  # because TDNN without padding consumes some frames
    )

    loss = k2.ctc_loss(
        decoding_graph=decoding_graph,
        dense_fsa_vec=dense_fsa_vec,
        output_beam=params.beam_size,  # 10
        reduction=params.reduction,  # a string: 'none', 'mean', or 'sum'
        use_double_scores=params.use_double_scores,  # True
    )

where supervision_segments is built as follows, since my "recordings" all have the same length:

    supervision_segments = torch.cat(
        [torch.arange(0, batch_size, 1, dtype=torch.int32).unsqueeze_(1),
         torch.tensor([[0, frame_num]] * batch_size, dtype=torch.int32)],
        dim=1
    )
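
As a rough check on the allow_truncate comment above, the number of frames an unpadded Conv1d consumes follows from the output-length formula in the PyTorch documentation. The sketch below is plain Python, not the poster's code: conv1d_output_frames, input_frames, and batch_size are hypothetical names and values, and the rows only mirror the [sequence index, start frame, duration] layout of supervision_segments.

```python
def conv1d_output_frames(num_frames, kernel_size, dilation=1, padding=0, stride=1):
    """Output length of torch.nn.Conv1d along the time axis (PyTorch docs formula)."""
    return (num_frames + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


input_frames = 100  # hypothetical number of frames fed to the network
output_frames = conv1d_output_frames(input_frames, kernel_size=5, dilation=1)
frames_consumed = input_frames - output_frames

# One row per sequence: [sequence index, start frame, duration in output frames].
batch_size = 3  # hypothetical
segments = [[i, 0, output_frames] for i in range(batch_size)]
```

With kernel_size=5 and dilation=1 the layer consumes 4 frames. Whether allow_truncate=3 is enough therefore depends on whether frame_num counts input or output frames, which the post does not state, so it may be worth checking.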

danpovey commented 2 years ago

These kinds of warnings would normally be due to numerical roundoff, either on long sequences or where the scores have large numerical values.
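
One way to see how large score magnitudes alone can produce differences above 1.0: float32 values between 2^21 and 2^22 are spaced 0.25 apart, so every addition at that magnitude is rounded to a multiple of 0.25. The sketch below emulates float32 with Python's standard struct module and sums a hypothetical per-frame log-score over a long sequence. It is a simplified illustration under stated assumptions (a single path, made-up values, helper names invented here), not k2's forward-backward computation. The differences reported in the warnings are all multiples of 0.25, which may be consistent with this effect but does not prove it.

```python
import math
import struct


def to_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def f32_spacing(x):
    """Gap between float32(x) and the next larger float32, for positive x."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + 1))[0] - to_f32(x)


def accumulate_f32(scores):
    """Sum scores left to right, rounding to float32 after every addition."""
    total = 0.0
    for s in scores:
        total = to_f32(total + to_f32(s))
    return total


# Hypothetical long sequence with a large-magnitude log-score on every frame.
num_frames = 100_000
scores = [-30.3] * num_frames

reference = math.fsum(scores)        # accurate float64 sum
low_precision = accumulate_f32(scores)
drift = abs(low_precision - reference)

print("float32 spacing near 3e6:", f32_spacing(3_000_000.0))
print("drift after", num_frames, "frames:", drift)
```

Under these assumptions the accumulated roundoff is far larger than 1.0, which is the kind of gap the warning reports. Shorter sequences or smaller score magnitudes shrink it.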