k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

training with kaldi imported features #238

Open armusc opened 2 years ago

armusc commented 2 years ago

Hi

I'm trying to train a model (in this case a BPE-500 conformer CTC-attention encoder-decoder) with features imported from Kaldi. I'm resorting to this because I cannot extract features fast enough with lhotse, in the absence of a job scheduler across different machines that would let me parallelize over tens, possibly hundreds of jobs.
The features are the 40-dimensional high-resolution features from a previous Kaldi chain training. I had a training run whose curves looked like this: [tensorboard screenshot]

Even though the attention loss does not seem to converge very well, the CTC loss seems to converge, and so does the overall loss. Anyway, I tried decoding with ctc-decoding to get a fast result without CUDA OOM and validate the model, but nothing works: the hypotheses are all empty. I printed the per-frame max-valued token id and I always get 0, which corresponds to "blk". I even tried decoding the training corpus itself (so there cannot be any mismatch issue) and it does not work either, which seems counterintuitive given those training curves. When I extract features with lhotse over the same corpus, the results are overall what I would expect.

Do you have any tips on where I should look, where the issue(s) might lie? Has anyone tried training and decoding with Kaldi pre-computed features? I can provide additional info if you find it useful (well, except for the corpus).

csukuangfj commented 2 years ago

The tensorboard log shows that the model has been trained for only about 500 batches and has not converged yet.

I printed the per-frame max-valued token id and I always get 0, which corresponds to "blk"

If your CTC loss reaches below 0.1, you may get reasonable results with greedy search.
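
For reference, a minimal greedy (argmax) CTC decoding sketch over the per-frame log-posteriors, assuming blank has token id 0 as in the icefall BPE setups; this is only a sanity check, not the recipe's decoding code:

import torch


def ctc_greedy_search(log_probs: torch.Tensor, blank_id: int = 0) -> list:
    # log_probs: (T, vocab_size) per-frame log-posteriors from the CTC head.
    # Pick the best token per frame, collapse repeats, then drop blanks.
    best = log_probs.argmax(dim=-1).tolist()
    hyp = []
    prev = None
    for tok in best:
        if tok != prev and tok != blank_id:
            hyp.append(tok)
        prev = tok
    return hyp  # token ids; map them through the BPE model to get text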

csukuangfj commented 2 years ago

Do you have any tips on where I should look, where the issue(s) might lie?

Can you train your model for more batches? If you have a very large dataset, you may want to first use a small subset of it to verify that the training pipeline works and that you are able to get a reasonable WER with it.
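
In lhotse terms, carving out such a subset could look roughly like this (a sketch; the manifest paths are placeholders):

from lhotse import load_manifest

# Load the full training cuts and keep only the first 1000 for a quick
# end-to-end check of the training and decoding pipeline.
cuts = load_manifest("data/fbank/cuts_train.jsonl.gz")
small = cuts.subset(first=1000)
small.to_file("data/fbank/cuts_train_small.jsonl.gz")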

pzelasko commented 2 years ago

Kaldi features might be larger in magnitude because the audio is scaled to INT16MAX instead of the [-1, 1] range. You might need some sort of normalization unless the model’s first layer is batch norm.
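
For illustration, a minimal sketch of per-utterance mean/variance normalization applied to an imported Kaldi feature matrix before it reaches the model (not part of the recipe, just one way to realize the suggestion):

import torch


def normalize_features(feats: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # feats: (T, F) features imported from Kaldi, potentially with a large dynamic range.
    # Subtract the per-utterance mean and divide by the per-utterance std so the
    # inputs land in a range comparable to features computed from [-1, 1] audio.
    mean = feats.mean(dim=0, keepdim=True)
    std = feats.std(dim=0, keepdim=True)
    return (feats - mean) / (std + eps)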

Re Lhotse: it’s possible to leverage distributed feature computation through the dask library. See the snowfall feature extraction here for an example: https://github.com/k2-fsa/snowfall/blob/911198817edc7b306265f32447ef8a7dc5cfa8f2/egs/librispeech/asr/simple_v1/prepare.py#L26
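
Roughly, the idea is to hand a distributed executor to lhotse's feature-extraction call; a sketch assuming a dask scheduler is already running (the address, paths, and num_jobs are placeholders, and the exact wiring may differ from the snowfall script linked above):

from dask.distributed import Client
from lhotse import CutSet, Fbank

# Connect to an already-running dask scheduler; its workers do the extraction.
client = Client("tcp://scheduler-host:8786")  # placeholder address

cuts = CutSet.from_file("data/manifests/cuts_train.jsonl.gz")
cuts = cuts.compute_and_store_features(
    extractor=Fbank(),
    storage_path="data/fbank/feats_train",
    num_jobs=100,
    executor=client,  # dask's Client exposes a concurrent.futures-style submit()
)
cuts.to_file("data/fbank/cuts_train.jsonl.gz")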

armusc commented 2 years ago

OK, thanks. First thing, I'll resume training to get more batches. Those actually should correspond to 78 epochs; that's the value I took from the librispeech training, but it looks like each epoch has very few batches, like 5 or 6, with about 17000 frames in total, which should correspond to 680 seconds (17000 * 4 / 100). That does not make sense since the train manifest represents 320 h. I have max_duration = 100 and num_buckets = 10, but I have also set buffer_size = 200 in DynamicBucketingSampler (this is because I could get some batches to be processed before the CPU went OOM on a bigger dataset), so maybe I'll retrain with the default value.
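
For context, the sampler setup being described is roughly this (a sketch of the lhotse call, with the values quoted above; the manifest path is a placeholder):

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler

cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")

# max_duration is the per-batch budget in seconds; buffer_size controls how many
# cuts the sampler holds in memory while filling its duration buckets.
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=100,
    num_buckets=10,
    buffer_size=200,  # increased to 10000 later in the thread
    shuffle=True,
)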

armusc commented 2 years ago

Kaldi features might be larger in magnitude because the audio is scaled to INT16MAX instead of the [-1, 1] range. You might need some sort of normalization unless the model’s first layer is batch norm.

Does this significantly impact the convergence speed? Yes, I did not normalize the features; they are stored unnormalized. On the other hand, speaker-based CMVN was computed, so I could apply it before storing and importing the features.

Re Lhotse: it’s possible to leverage distributed feature computation through the dask library. See the snowfall feature extraction here for an example: https://github.com/k2-fsa/snowfall/blob/911198817edc7b306265f32447ef8a7dc5cfa8f2/egs/librispeech/asr/simple_v1/prepare.py#L26

ok, thanks

pzelasko commented 2 years ago

Does this significantly impact the convergence speed? Yes, I did not normalize the features; they are stored unnormalized. On the other hand, speaker-based CMVN was computed, so I could apply it before storing and importing the features.

I expect there to be some impact. In the past, when I was training transformer ASR with Espresso, input normalization was the difference between complete divergence and good convergence.

armusc commented 2 years ago

I am re-training from scratch with CMVN-normalized features (40-dimensional PLP); I can see that almost all coefficients are between -1 and 1 now. Btw, I changed buffer_size of DynamicBucketingSampler to 10000 (I had previously set it to 200) and I can see that many more batches are being processed within the same epoch; for some reason I had only about 17000 frames being processed per epoch before.

Also, as I explained in another thread, I have yet to see a benefit from using lazy loading and dynamic bucketing: this corpus has 18 GB of features and I can see 31 GB of CPU RAM being consumed constantly (one worker), but I'll focus on that after this training has been validated.
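
For reference, the eager vs. lazy manifest loading being contrasted here looks roughly like this in lhotse (a sketch; paths are placeholders):

from lhotse import load_manifest, load_manifest_lazy

# Eager: the whole manifest is parsed into memory up front.
cuts_eager = load_manifest("data/fbank/cuts_train.jsonl.gz")

# Lazy: cuts are read from the JSONL file as they are iterated, which should keep
# resident memory roughly constant regardless of corpus size.
cuts_lazy = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")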

pzelasko commented 2 years ago

With a buffer_size of 200 you might have run into accidental premature depletion in the dynamic sampler. I'll think about whether there's a way to reliably detect that and emit a warning.

I'll also try to think about how to debug the lazy manifest issue on your side. If you can reduce your code to a minimal example that reproduces the problem and post it here as a snippet, that would be very helpful.

videodanchik commented 2 years ago

Hi, @armusc. I have a similar setup to the one you described: 40-dimensional MFCCs from a previous chain model, 17000 hours (this includes noisy copies plus speed and volume perturbations) from several 8 kHz and 16 kHz datasets combined (the 16 kHz ones were downsampled). The first thing to mention is that I apply an inverse cosine transform to the MFCCs to retrieve the fbanks. You can use this function; it's a small modification of Kaldi's implementation:

import numpy as np
import torch


def compute_lifter_coeffs(cepstral_lifter: float, K: int) -> np.ndarray:
    # Kaldi-style cepstral lifter coefficients: c[k] = 1 + 0.5 * Q * sin(pi * k / Q).
    return 1.0 + 0.5 * cepstral_lifter * np.sin(np.pi * np.arange(K) / cepstral_lifter)


def compute_idct_matrix(K: int = 40, N: int = 40, cepstral_lifter: float = 22.0) -> torch.Tensor:
    # Orthonormal DCT-II matrix mapping N log-mel bins to K cepstral coefficients.
    matrix = np.zeros((N, K), dtype=np.float32)
    matrix[:, 0] = 1.0 / N**0.5

    for k in range(1, K):
        for n in range(0, N):
            matrix[n, k] = np.cos(np.pi / N * (n + 0.5) * k)

    matrix[:, 1:] *= (2.0 / N) ** 0.5

    # Undo the cepstral liftering that Kaldi applies when computing MFCCs.
    if cepstral_lifter > 0.0:
        matrix /= compute_lifter_coeffs(cepstral_lifter, K)

    # Transposed so that fbank = torch.matmul(mfcc, idct) recovers the log-mel features.
    return torch.FloatTensor(matrix.T)

Having this idct matrix, you can simply do fbank = torch.matmul(mfcc, idct), where mfcc has shape (T, 40). Another thing to mention is that I reduced the number of masked frequencies in SpecAugment: since the regular icefall setup extracts 80 filter banks instead of 40, I reduced features_mask_size from 27 to 13. Next, I apply sliding-window mean normalization to the filter banks extracted from the MFCCs, with a window of 3 seconds (300 frames), but no variance normalization. I'm getting the following TensorBoard logs: [tensorboard screenshot]
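
A minimal sketch of that sliding-window mean normalization (a naive centered 300-frame window, mean only, following the description above rather than any particular library call):

import torch


def sliding_window_mean_norm(feats: torch.Tensor, window: int = 300) -> torch.Tensor:
    # feats: (T, F) filter banks. For each frame, subtract the mean computed over a
    # window of `window` frames centered on it; no variance normalization.
    T = feats.size(0)
    half = window // 2
    out = torch.empty_like(feats)
    for t in range(T):
        start = max(0, t - half)
        end = min(T, t + half + 1)
        out[t] = feats[t] - feats[start:end].mean(dim=0)
    return out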

As you can see, I also didn't manage to make the ctc_loss converge to the 0.05-0.1 range, though I got reasonable WERs for my test sets; they are all out-of-domain, noisy, accented telephony speech. One thing that still bothers me a bit about this transition from 80 to 40 fbanks: relative to the number of frequency bins, the receptive field of the first two 2-D convolutions in ConvolutionSubsampling is twice as large in the 40-fbank setup as in the 80-fbank one. I'm going to play around with the size of the filters and strides in the Conv2d layers of ConvolutionSubsampling.