yl4579 / PL-BERT

Phoneme-Level BERT for Enhanced Prosody of Text-to-Speech with Grapheme Predictions

Possible bug in masked index generation? #39

Open tekinek opened 10 months ago

tekinek commented 10 months ago

https://github.com/yl4579/PL-BERT/blob/592293aabcb21096eb7f5bffad95a3d38ba4ae6c/dataloader.py#L83

Hi, why is masked_index extended for 15% of the tokens? If I understand correctly, the extension should be placed inside the else statement at line 80 of dataloader.py, right?

jav-ed commented 10 months ago

A few days ago, I was wondering about the very same thing. We would only want a token to be registered as masked when it is actually masked or otherwise modified. The current code is as follows:

phoneme_list = ''.join(phonemes)
masked_index = []
for z in zip(phonemes, input_ids):
    z = list(z)

    words.extend([z[1]] * len(z[0]))
    words.append(self.word_separator)
    labels += z[0] + " "

    if np.random.rand() < self.word_mask_prob:
        if np.random.rand() < self.replace_prob:
            if np.random.rand() < (self.phoneme_mask_prob / self.replace_prob): 
                phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))] for _ in range(len(z[0]))])  # randomized
            else:
                phoneme += z[0]
        else:
            phoneme += self.token_mask * len(z[0]) # masked

        masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())
    else:
        phoneme += z[0] 

    phoneme += self.token_separator

From what I can tell, the original goal probably was:

- 85% of the time: keep original
- 12% of the time: special phoneme mask
- 1.5% of the time: random (from the available) phoneme mask
- 1.5% of the time: selected for masking but kept as the original phoneme

--> Effectively:

- 86.5% of the time: keep original
- 12% of the time: special phoneme mask
- 1.5% of the time: random (from the available) phoneme mask
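For reference, these numbers follow directly from the probabilities quoted further down in this thread (word_mask_prob = 0.15, replace_prob = 0.2, phoneme_mask_prob = 0.1); a quick arithmetic check:

word_mask_prob, replace_prob, phoneme_mask_prob = 0.15, 0.2, 0.1

p_keep          = 1 - word_mask_prob                                                      # 0.85
p_special_mask  = word_mask_prob * (1 - replace_prob)                                     # 0.12
p_random        = word_mask_prob * replace_prob * (phoneme_mask_prob / replace_prob)      # 0.015
p_keep_but_mark = word_mask_prob * replace_prob * (1 - phoneme_mask_prob / replace_prob)  # 0.015

assert abs(p_keep + p_special_mask + p_random + p_keep_but_mark - 1.0) < 1e-9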

However, the current dataloader code quoted at the top also registers a masked index in the 1.5% of cases where the token was selected for masking but the original phoneme was kept (the second 1.5% in the list above). Even though incorrectly marking 1.5% of tokens probably has a negligible impact on performance, consider the following correction suggestion:

phoneme_list = ''.join(phonemes)
masked_index = []
for z in zip(phonemes, input_ids):
    z = list(z)

    words.extend([z[1]] * len(z[0]))
    words.append(self.word_separator)
    labels += z[0] + " "

    if np.random.rand() < self.word_mask_prob:
        if np.random.rand() < self.replace_prob:
            if np.random.rand() < (self.phoneme_mask_prob / self.replace_prob): 
                phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))] for _ in range(len(z[0]))])  # randomized

                # added here
                masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())

            else:
                phoneme += z[0]
        else:
            phoneme += self.token_mask * len(z[0]) # masked

            # added here
            masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())

        # removed here
        # masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())

    else:
        phoneme += z[0] 

    phoneme += self.token_separator
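To quantify the difference, here is a small standalone simulation (not the repo's dataloader, just the same branching) that counts how often a word would end up registered in masked_index under the current versus the suggested logic:

import numpy as np

# Standalone sanity check, not the repo's code: replay only the branching and
# count how often a word's positions would be added to masked_index.
word_mask_prob, replace_prob, phoneme_mask_prob = 0.15, 0.2, 0.1
rng = np.random.default_rng(0)

n_words = 200_000
current_hits = suggested_hits = 0
for _ in range(n_words):
    if rng.random() < word_mask_prob:
        current_hits += 1                    # current code: always registered once selected
        if rng.random() < replace_prob:
            if rng.random() < phoneme_mask_prob / replace_prob:
                suggested_hits += 1          # random replacement -> registered
            # else: phoneme kept as-is -> not registered in the suggestion
        else:
            suggested_hits += 1              # special mask token -> registered

print(current_hits / n_words)     # ~0.15
print(suggested_hits / n_words)   # ~0.135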
tekinek commented 10 months ago

Yes, here is the distribution. For each word (token) in a given sample:

1   85%: keep original
2   15%: 
3        - 80%: whole word masking (e.g. nice —> MMMM)  
4        - 20%:
5              - 50%: random replacement of every phoneme in it (e.g. nice —> csoe)
6              - 50%: keep original

I think the masked index should only be registered for cases 3 and 5, but the current implementation covers case 2, which I think is a bug.

yl4579 commented 10 months ago

Thanks for your question. This was intentional. The masked indices are used for loss calculation here: https://github.com/yl4579/PL-BERT/blob/main/train.ipynb (see the if len(_masked_indices) > 0: line), so the masked indices also include unchanged tokens, and the model is trained (1.5% of the time) to reproduce the exact input tokens, guided by the loss. If we don't include this, the model will not learn to keep the original tokens when they are unmasked (as at inference time, when you use it for TTS fine-tuning).
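In other words (a minimal sketch with illustrative names, not the exact notebook code), the reconstruction loss is only accumulated at the positions listed in _masked_indices, so a position that was registered but left unchanged still has to be reproduced by the model:

import torch
import torch.nn.functional as F

# Minimal sketch, illustrative names only: restrict the token-level loss to the
# positions stored in masked_indices, mirroring the if len(_masked_indices) > 0
# guard mentioned above.
def masked_token_loss(logits, targets, masked_indices):
    # logits: (seq_len, vocab_size), targets: (seq_len,)
    if len(masked_indices) == 0:
        return logits.new_zeros(())
    idx = torch.tensor(masked_indices, dtype=torch.long)
    return F.cross_entropy(logits[idx], targets[idx])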

tekinek commented 10 months ago

Thanks for your clarification. I have trained PL-BERT for my language and tried to evaluate it by asking it to predict masked/unmasked tokens and phonemes. In most cases its predictions make sense, but it fails at predicting the space between words (which is used as the token separator in this repo). Across different checkpoints it predicts the space as a random phoneme, but never as the space itself, even when the space is not masked in the input.

So I further trained the model while additionally masking the token separator 15% of the time. Now the model can predict the space.

                phoneme += self.token_separator
                # additionally register the separator position as a masked index
                if np.random.rand() < self.phoneme_mask_prob:
                    masked_index.extend((np.arange(len(phoneme) - 1, len(phoneme))).tolist())
jav-ed commented 10 months ago

@tekinek When it comes to the distribution, here are some snippets of my code. Note that the code was not intended for publication; the comments were just for my own understanding, so please overlook spelling and similar mistakes. The main point (the probability distribution) should still be clear from the comments, inshallah.

      # 85% of the time: keep original
      # 12% of the time: special phoneme mask
      # 1.5% of the time: random (from the available) phoneme mask
      # 1.5% of the time: selected for masking but kept the original phoneme

      # --> 86.5% of the time: keep original
      # 12% of the time: special phoneme mask
      # 1.5% of the time: random (from the available) phoneme mask

      # word_mask_prob = 0.15
      # np.random.rand() --> random number in [0, 1)
      # true ~15% of the time
      if np.random.rand() < self.word_mask_prob:

          # replace_prob = 0.2
          # true ~20% of the time within this branch
          # i.e. 0.15 * 0.2 = 0.03 --> ~3% of the time overall
          if np.random.rand() < self.replace_prob:

              # ------------------ random replacement ------------------ #
              # phoneme_mask_prob=0.1, replace_prob=0.2
              # 0.1/0.2 = 0.5 
              # 0.03 * 0.5 = 0.015 = 1.5% of the time: replace the word's phonemes with random phonemes
              if np.random.rand() < (self.phoneme_mask_prob / self.replace_prob): 

                  # np.random.randint(0, len(phoneme_list)) = get an integer in [0, len(phoneme_list)),
                  # i.e. np.random.randint(a, b) returns an int between a and (b - 1), never b itself

                  # for the length of the current word's phonemes, pick random phonemes from the
                  # available phoneme inventory and append them to the collected phoneme string
                  phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))] for _ in range(len(z[0]))
                                      ])  # randomized

                  masked_index.extend((np.arange(len(phoneme) - len(z[0]), 
                                         len(phoneme))).tolist())

              # --------------------- take original -------------------- #
              # considered for masking but kept original:
              # 0.03 *(1 - 0.5) = 0.03 * 0.5 = 0.015 = 1.5%
              else:
                  phoneme += z[0]

          # ------------------- special token masking ------------------ #
          # for ~ 0.15 * (1- 0.2 = 0.8) = 0.15 * 0.8 = 0.12 of the time
          # for ~ 12% of the time special mask token
          else:

              # add the mask token self.token_mask ("M") once per phoneme of the word
              phoneme += self.token_mask * len(z[0]) # masked

              masked_index.extend((np.arange(len(phoneme) - len(z[0]), 
                                         len(phoneme))).tolist())

          ## Modification made here: this line should only be executed when masking actually occurs,
          ## not when the original is kept; otherwise it runs for the 1.5% of cases in which no
          ## masking is actually in place
          # masked_index.extend((np.arange(len(phoneme) - len(z[0]), 
          #                                len(phoneme))).tolist()
          #                     )

      # -------------------- keep original phonemes -------------------- #
      # for ~ 1 - 0.15 = 0.85 of the time --> do not mask and keep the original phoneme
      else:
          phoneme += z[0] 

      phoneme += self.token_separator

  # count the phonemes in the full phoneme collection string
  mel_length = len(phoneme)
jav-ed commented 10 months ago

Please feel free to correct me if I made a mistake in the probability distribution calculations.

tekinek commented 10 months ago

@jav-ed Your calculation looks correct. However, as @yl4579 clarified, we don't need to change the distribution. But I still suggest adding the following lines right after the last phoneme += self.token_separator in your code, which masks the token separator (space), if your language uses spaces between words.

if np.random.rand() < self.phoneme_mask_prob:
    masked_index.extend((np.arange(len(phoneme) - 1, len(phoneme))).tolist())
yl4579 commented 10 months ago

@tekinek The token separator doesn't need to be predicted because it has a one-to-one correspondence between the grapheme and phoneme domains (i.e., the space token in the phoneme domain always corresponds to the word separator token in the grapheme domain). Even if the linear projection head fails at predicting this specific token, it won't affect the downstream task, because a white-space phoneme token means exactly word separator.
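Put differently, since the mapping is deterministic, a failed separator prediction can simply be overwritten downstream; a hypothetical sketch (identifiers are illustrative, not from the repo):

# Hypothetical sketch, not the repo's code: a space phoneme always corresponds
# to the word-separator grapheme, so its prediction can be fixed up deterministically.
def fix_separator_predictions(phonemes, predicted_ids, space_phoneme=" ", word_separator_id=0):
    return [word_separator_id if p == space_phoneme else pred
            for p, pred in zip(phonemes, predicted_ids)]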

jav-ed commented 10 months ago

@tekinek Thank you, yes, now I see why there is an advantage in not moving the masked_index.extend into the individual masking branches. Just for anybody else who might not understand it immediately: basically, @yl4579 explained that he is tricking the BERT model on purpose. He makes the model believe that something is masked while it is not. Through this, the model learns that some tokens are already correct and thus should not be replaced.

@yl4579 thank you for your explanation