facebookresearch / gtn_applications

Applications using the GTN library and code to reproduce experiments in "Differentiable Weighted Finite-State Transducers"
MIT License

[STC] STC loss ascends while training #21

Closed LEECHOONGHO closed 2 years ago

LEECHOONGHO commented 2 years ago

Hello, I'm training an ASR model with the STC loss and a letter-to-word encoder, as shown below. But as training progresses, the STC loss ascends and becomes inf after 12,000 steps.

Is there a mistake in my implementation? Any help would be appreciated. Thank you.

training args: blank_idx=0, p0=0.05, plast=0.15, thalf=16000

    self.criterion = STC(
        blank_idx=self.cfg.blank_idx,
        p0=self.cfg.p0,
        plast=self.cfg.plast,
        thalf=self.cfg.thalf,
        reduction="mean",
    )

    # model_output:    Tensor[batch_size, max_frame_length, n_letter_symbols * max_word_length]
    # self.l2w_matrix: Tensor[n_letter_symbols * max_word_length, n_word_symbols]
    word_level_output = model_output @ self.l2w_matrix

    # word_level_output: Tensor[batch_size, max_frame_length, n_word_symbols]
    word_level_output = F.log_softmax(word_level_output.transpose(1, 0), dim=-1)

    loss = self.criterion(word_level_output, word_labels)
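For reference, a minimal shape check of the projection above (the sizes here are placeholders, not my actual config):

    import torch

    # Placeholder sizes, just to verify that the shapes line up.
    batch_size, max_frame_length = 4, 100
    n_letter_symbols, max_word_length, n_word_symbols = 70, 20, 5000

    model_output = torch.randn(batch_size, max_frame_length,
                               n_letter_symbols * max_word_length)
    l2w_matrix = torch.randn(n_letter_symbols * max_word_length, n_word_symbols)

    word_level_output = model_output @ l2w_matrix
    assert word_level_output.shape == (batch_size, max_frame_length, n_word_symbols)

    # transpose(1, 0) makes the input time-major ([T, B, C]), as CTC-style losses expect.
    word_level_output = torch.log_softmax(word_level_output.transpose(1, 0), dim=-1)
    assert word_level_output.shape == (max_frame_length, batch_size, n_word_symbols)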


![stcloss](https://user-images.githubusercontent.com/44384060/159213516-d5ccca1c-31ac-48c3-9afc-70fc093d0a69.PNG)
vineelpratap commented 2 years ago

Hi, could you give a few details about the dataset you are using and how the partial labels are created? Also, have you tried CTC with the same training setup?

As a sanity check, you can output Tensor[batch_size, max_frame_length, n_word_symbols] from the model directly; this will rule out errors in the creation of l2w_matrix. A sketch is below.
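Something like this hypothetical word-level head (hidden_dim and the sizes are placeholders for whatever your model uses):

    import torch
    import torch.nn as nn

    # Placeholder sizes; substitute your model's hidden size and word vocab size.
    hidden_dim, n_word_symbols = 768, 5000

    # Predict word logits directly, bypassing the letter-to-word projection.
    word_head = nn.Linear(hidden_dim, n_word_symbols)

    encoder_features = torch.randn(4, 100, hidden_dim)  # [B, T, hidden_dim]
    word_level_output = word_head(encoder_features)     # [B, T, n_word_symbols]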

Also, just to make sure: are your word_labels 1-indexed to account for the blank symbol?
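For example, with blank_idx=0 the mapping should reserve index 0 for blank (sketch; the word list is hypothetical):

    # With blank_idx=0, index 0 is reserved for the blank symbol,
    # so real words must be mapped to 1..n_words.
    words = ["word1", "word2", "word3"]  # hypothetical vocabulary
    word_to_idx = {w: i + 1 for i, w in enumerate(words)}

    # word_labels passed to the criterion should then only contain values >= 1.
    assert min(word_to_idx.values()) == 1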

LEECHOONGHO commented 2 years ago

Hello @vineelpratap. I'm using a mixed Korean ASR dataset (num_audio=3,180,000) from AI Hub. The way I build l2w_matrix is shown below.

I'm now also trying to train with CTC (gtn_applications). Thank you for your advice.

import json

import torch

def get_l2w_matrix(self):
    # letters = {'BLANK': 0, 'PAD': 1, 'letter1': 2, ...}
    # words   = {'BLANK': 0, 'word1': 1, ...}
    with open(self.config.letter_dict_path, 'r', encoding='utf8') as j:
        letters = json.load(j)
    with open(self.config.word_dict_path, 'r', encoding='utf8') as j:
        words = json.load(j)

    # One slot per (letter position, letter symbol, word).
    E_matrix = torch.zeros(self.config.max_word_length, len(letters), len(words)).bool()

    # Set a one-hot letter vector at each position of each word.
    for word, word_idx in words.items():
        if word == 'BLANK':
            E_matrix[0, letters['BLANK'], word_idx] = True
        else:
            for letter_idx, letter in enumerate(word2letter(word)):
                E_matrix[letter_idx, letters[letter], word_idx] = True

    # Fill unused letter positions with the PAD symbol.
    padding_location = ~torch.any(E_matrix, 1)
    E_matrix[:, letters['PAD'], :][padding_location] = True

    torch.save(E_matrix, self.config.e_matrix_path)

    # Flatten to [max_word_length * n_letter_symbols, n_word_symbols].
    return E_matrix.half().view(-1, E_matrix.shape[-1])
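To sanity-check the construction on a toy vocabulary (standalone sketch; word2letter is replaced here by plain character splitting):

    import torch

    # Toy dictionaries mirroring the JSON files above.
    letters = {'BLANK': 0, 'PAD': 1, 'a': 2, 'b': 3}
    words = {'BLANK': 0, 'ab': 1, 'ba': 2}
    max_word_length = 3

    E = torch.zeros(max_word_length, len(letters), len(words)).bool()
    for word, word_idx in words.items():
        if word == 'BLANK':
            E[0, letters['BLANK'], word_idx] = True
        else:
            for letter_idx, letter in enumerate(word):  # word2letter -> chars here
                E[letter_idx, letters[letter], word_idx] = True

    # Fill unused positions with PAD so every (position, word) slot is one-hot.
    E[:, letters['PAD'], :][~torch.any(E, 1)] = True
    assert torch.all(E.sum(1) == 1)  # exactly one active letter per slot

    l2w = E.half().view(-1, len(words))  # [max_word_length * n_letters, n_words]
    print(l2w.shape)  # torch.Size([12, 3])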
LEECHOONGHO commented 2 years ago

When I changed thalf from 16000 to 8000 and made the model output word-level logits directly, the STC loss still increases slightly after 8k steps.

More details about my model:

  1. I feed normalized waveforms directly into a wav2vec 2.0-style model (WavLM).
  2. I changed the total stride of the wav feature extractor from 320 to 1280 (80 ms per output frame).
  3. The number of words per second in the audio data ranges from 0.5 to 6.5 (there are many unvoiced segments). A rough frame-budget check is below.
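A back-of-envelope check of the frame budget after the stride change (assuming 16 kHz input audio):

    # 1280-sample stride at 16 kHz -> 80 ms per frame -> 12.5 frames/s.
    sample_rate = 16000  # assumed; typical for wav2vec 2.0-style models
    stride = 1280
    frames_per_sec = sample_rate / stride  # 12.5

    for words_per_sec in (0.5, 6.5):
        print(f"{words_per_sec} words/s -> "
              f"{frames_per_sec / words_per_sec:.1f} frames per word")

    # CTC-style losses cannot align when the label sequence is longer than
    # the number of frames; at 6.5 words/s there are fewer than 2 frames
    # per word, which can drive the loss toward inf on dense utterances.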

[figure: loss2]