facebookresearch / gtn_applications

Applications using the GTN library and code to reproduce experiments in "Differentiable Weighted Finite-State Transducers"
MIT License

[STC] Question about ASR training. #22

Open LEECHOONGHO opened 2 years ago

LEECHOONGHO commented 2 years ago

Hello, I'm trying to implement the ASR model proposed in Star Temporal Classification, but I'm having trouble implementing my first word-level-output ASR model.

When I use a simple one-hot word-encoder tensor (the E matrix), the STC loss increases, so I made several modifications to address this.

(1) I reshape x from [T, B, A_L × l_max] to [T, B, l_max, A_L], apply F.log_softmax(dim=-1), and reshape it back to the original shape [T, B, A_L × l_max].

(2) I then apply the letter-to-word encoder e_matrix : [A_L × l_max, A_W] as x = x @ e_matrix, and use F.log_softmax(torch.exp(x), dim=-1) as the STC input.

I applied the softmax over A_L so that x represents the probability of each letter appearing at each position in the word. The log -> e_matrix -> exp sequence converts the one-hot sum computed by the e_matrix product into a product of probabilities.

1. Is this the right way to build a letter-to-word encoder?
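
The letter-to-word scoring described in steps (1) and (2) can be sketched for a single frame in plain Python. This is only an illustration under stated assumptions, not the STC authors' recipe: the toy `lexicon`, the helper names, and the final renormalization over the word vocabulary directly in the log domain are all hypothetical choices, and words shorter than l_max would additionally need a padding letter. Summing per-position log-probabilities for a word is what one column of a one-hot e_matrix computes in x @ e_matrix.

```python
import math

def log_softmax(scores):
    """Return log-probabilities for a list of raw scores."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - lse for s in scores]

def word_log_probs(frame_logits, lexicon):
    """Score every word for one frame.

    frame_logits: l_max lists of A_L raw scores (one list per letter position).
    lexicon: hypothetical list of words, each a tuple of l_max letter indices.
    """
    # Step (1): normalize over the alphabet separately at each position.
    letter_lp = [log_softmax(pos) for pos in frame_logits]
    # Step (2): adding log-probabilities multiplies the letter probabilities,
    # which is what a one-hot e_matrix column selects and sums for each word.
    scores = [sum(letter_lp[p][c] for p, c in enumerate(word)) for word in lexicon]
    # Assumption: renormalize over the word vocabulary in the log domain.
    return log_softmax(scores)

# Toy setup: A_L = 3 letters, l_max = 2 positions, A_W = 3 words.
frame_logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
lexicon = [(0, 1), (1, 0), (2, 2)]
lp = word_log_probs(frame_logits, lexicon)
best = max(range(len(lexicon)), key=lambda w: lp[w])
```

Here `best` is the index of the word whose letters match the highest-scoring letter at each position, and `lp` sums to one after exponentiation.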

After applying this, the STC loss starts at 3.5 and falls to 2.1~2.5 over 15 epochs. However, when I check WER on each predicted output, the Viterbi-decoded output (implemented in gtn_applications/criterions/ctc.py) is always BLANK.
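
To see how an all-BLANK output arises, here is a minimal sketch of greedy best-path CTC decoding, together with a blank-fraction diagnostic. It is not the GTN-based implementation in gtn_applications/criterions/ctc.py; the function names and the choice of `blank=0` are assumptions made for illustration. If the blank token wins the argmax in every frame, the decoded sequence is empty, so tracking the blank fraction over training can show whether the model is slowly moving away from that state.

```python
def greedy_ctc_decode(frame_scores, blank=0):
    """Best-path decode: argmax per frame, merge repeats, drop blanks."""
    path = [max(range(len(f)), key=lambda k: f[k]) for f in frame_scores]
    decoded, prev = [], None
    for tok in path:
        # Emit a token only when it differs from the previous frame and is not blank.
        if tok != prev and tok != blank:
            decoded.append(tok)
        prev = tok
    return decoded, path

def blank_fraction(path, blank=0):
    """Fraction of frames whose best token is the blank."""
    if not path:
        return 0.0
    return sum(tok == blank for tok in path) / len(path)

# Toy emissions over 3 tokens (0 is the blank), built from a known frame path.
frame_path = [0, 1, 1, 0, 1, 2, 2, 0]
frame_scores = [[1.0 if k == t else 0.0 for k in range(3)] for t in frame_path]
decoded, path = greedy_ctc_decode(frame_scores)
frac = blank_fraction(path)
```

In this toy case the repeated 1 survives because a blank frame separates the two occurrences, while the consecutive 2s collapse into one token.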

Because CTC loss with word-level classification output gives the same result, I assume this is a problem with the properties of CTC training and word-level-output ASR.

  1. Is this a normal result?
  2. How many epochs are usually needed to get outputs other than blank?
  3. How many epochs are usually needed to reach the best performance?

I'm sorry to bother you every time.

vineelpratap commented 2 years ago

Hi,

vineelpratap commented 2 years ago

If it helps, I can try to share a recipe for how we trained Handwriting Recognition System with STC. Let me know.

LEECHOONGHO commented 2 years ago

Hello, @vineelpratap

For CTC, I used torch.nn.functional.ctc_loss, called as in Hugging Face's wav2vec2, to check my model.
The only difference between the word and letter tokens is the output dimension. But both have problems with WER and CER (neither is decreasing).

So I'm checking my acoustic model and data loader, and trying to filter the dataset down to the most reliable samples. If you could share the recipe for STC training, it would be a great help to me.

Besides, applying log-softmax or sigmoid over A_L didn't help WER.

Thank you.

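import torch
import torch.nn.functional as F
from torch import nn

# Mask of label positions greater than 0 (index 0 is treated as padding here)
labels_mask = labels > 0
target_lengths = labels_mask.sum(-1)
# Concatenate the unpadded targets of the whole batch into one 1-D tensor
flattened_targets = labels.masked_select(labels_mask)

# Log-probabilities in float32; logits are assumed to be [B, T, C] and are
# transposed to the [T, B, C] layout that ctc_loss expects
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

loss = nn.functional.ctc_loss(
    log_probs,
    flattened_targets,
    feature_lengths,
    target_lengths,
    blank=self.config.blank_idx,
    reduction=self.config.ctc_loss_reduction,
    zero_infinity=self.config.ctc_zero_infinity,
)

vineelpratap commented 2 years ago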

> I'm checking my acoustic model and data loader.

Sure.

If CTC is not working either, the problem seems to be somewhere else. Please also see https://github.com/facebookresearch/gtn_applications/blob/main/criterions/ctc.py#L110-L120 for a reference call to PyTorch CTC.
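
One data-side check that may be worth running when CTC does not train is whether every sample has enough frames for its target. The sketch below is a hypothetical helper, not part of gtn_applications or PyTorch: it relies only on the fact that a CTC alignment needs at least one frame per label plus one blank frame between equal neighboring labels. Samples that fail this have infinite loss, and when zero_infinity is enabled in torch.nn.functional.ctc_loss those losses are zeroed, so such samples contribute nothing without any visible error.

```python
def min_ctc_frames(target):
    """Fewest frames that can emit `target`:
    one per label plus one blank between equal neighboring labels."""
    repeats = sum(a == b for a, b in zip(target, target[1:]))
    return len(target) + repeats

def infeasible_samples(input_lengths, targets):
    """Indices of samples whose input is too short for any CTC alignment."""
    return [
        i
        for i, (n_frames, target) in enumerate(zip(input_lengths, targets))
        if n_frames < min_ctc_frames(target)
    ]

# Toy batch: frame counts after the encoder, and label sequences without padding.
input_lengths = [5, 3, 4]
targets = [[1, 2, 3], [1, 1, 2], [4, 4, 4]]
bad = infeasible_samples(input_lengths, targets)
```

In this toy batch the second and third samples are too short, because their repeated labels each require an extra blank frame.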

jasonppy commented 2 years ago

> If it helps, I can try to share a recipe for how we trained Handwriting Recognition System with STC. Let me know.

Hi Vineel, is it possible to share the recipe?

jasonppy commented 1 year ago

Hi @LEECHOONGHO,

If you are still interested in using STC, please get in touch!

Thanks,