v-iashin / SpecVQGAN

Source code for "Taming Visually Guided Sound Generation" (Oral at the BMVC 2021)
https://v-iashin.github.io/SpecVQGAN
MIT License

Training conditional transformer #16

Closed radiradev closed 2 years ago

radiradev commented 2 years ago

Hello,

I am trying to understand these lines. Could you elaborate further on the procedure for training the transformer here?

```python
# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices

# in the case we do not want to encode condition anyhow (e.g. inputs are features)
if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
    # make the prediction
    logits, _, _ = self.transformer(z_indices[:, :-1], c)
    # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
    if isinstance(self.transformer, GPTFeatsClass):
        cond_size = c['feature'].size(-1) + c['target'].size(-1)
    else:
        cond_size = c.size(-1)
    logits = logits[:, cond_size-1:]
```

Using the features and all of the indices, what exactly are we trying to predict? Isn't the target all of the z_indices, which we are already giving to the transformer? Or are we just predicting the last z_index given the features and the previous z_indices?

v-iashin commented 2 years ago

> Or are we just predicting the last z_index given the features and the previous z_indices?

Yes, we are. This is how the transformer is trained, and the same applies to any other network (e.g. an RNN) that autoregressively generates a sequence of tokens.

During training (and inference), the model predicts the next token given the current one and all previous ones.

Schematically:

```
c - condition token
z - data token
cond_size = 4

input:                 cccczzzz
input[:-1]:            cccczzz
Transformer:           |||||||
logits:                ccczzzz
logits[cond_size-1:]:     zzzz
loss:                     (compare)
targets:                  zzzz
```
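
To make the indexing concrete, here is a minimal, self-contained sketch (not the repository's code; the toy transformer, tensor shapes, and sizes are assumptions for illustration) showing how the `[:-1]` shift and the `logits[:, cond_size-1:]` slice align the predictions with `target` for a cross-entropy loss:

```python
import torch
import torch.nn.functional as F

# toy sizes (assumptions, not the repo's values)
B, cond_size, seq_len, vocab = 2, 4, 8, 128

# conditioning (e.g. visual features) and the spectrogram codebook indices
c = torch.randn(B, cond_size, 16)                  # stands in for the `c` passed to GPTFeats
z_indices = torch.randint(0, vocab, (B, seq_len))

def toy_transformer(z_in, c):
    # stand-in for `self.transformer(z_indices[:, :-1], c)`: it consumes
    # cond_size condition tokens followed by the shifted z tokens and
    # returns one logit vector per input position
    n_positions = c.size(1) + z_in.size(1)         # cond_size + (seq_len - 1)
    return torch.randn(z_in.size(0), n_positions, vocab)

target = z_indices                                 # every z token is a prediction target
logits = toy_transformer(z_indices[:, :-1], c)     # (B, cond_size + seq_len - 1, vocab)

# output i predicts token i+1 of the (condition + data) sequence, so the first
# z token is predicted by the last condition position, i.e. index cond_size - 1
logits = logits[:, cond_size - 1:]                 # (B, seq_len, vocab)

loss = F.cross_entropy(logits.reshape(-1, vocab), target.reshape(-1))
print(logits.shape, target.shape, loss.item())     # torch.Size([2, 8, 128]) torch.Size([2, 8]) ...
```

With `cond_size` condition positions, the model emits `cond_size + seq_len - 1` logits; the first `cond_size - 1` of them predict condition positions and are discarded, leaving exactly one logit per `z` token in `target`.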