Natooz / MidiTok

MIDI / symbolic music tokenizers for Deep Learning models 🎶
https://miditok.readthedocs.io/
MIT License

Octuple decode KeyError #206

Open NHeLv1 opened 1 week ago

NHeLv1 commented 1 week ago

I trained an autoregressive Transformer using Octuple, but a KeyError occurs when the output is eventually decoded into MIDI.

NHeLv1 commented 1 week ago

*(screenshot of the error traceback)*

Natooz commented 1 week ago

Hi, can you share more information on the tokenizer (ideally its configuration, e.g. as a tokenizer.json file) and the token ids for which this error occurs?
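
For reference, a minimal sketch of how the configuration could be exported to a tokenizer.json file for sharing; it assumes MidiTok's `save_params` method and the `params` constructor argument (the exact method name may differ across versions), and the `use_programs` option is only an illustrative choice:

```python
from pathlib import Path

from miditok import Octuple, TokenizerConfig

# Build the tokenizer from a config (illustrative parameters, adapt to your own)
config = TokenizerConfig(use_programs=True)
tokenizer = Octuple(config)

# Dump the full configuration to a JSON file that can be attached to the issue
tokenizer.save_params(Path("tokenizer.json"))

# The same file can later be used to re-create an identical tokenizer
tokenizer_reloaded = Octuple(params=Path("tokenizer.json"))
```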

NHeLv1 commented 1 week ago

*(screenshot of the tokenizer config)* I haven't used a JSON file, just a config, and I don't get this error with REMI. The error occurs when I call the tokenizer as below:

```python
def generate(self, inputs, mask, max_new_tokens, tokenizer):

    output = inputs.clone()
    for _ in range(max_new_tokens):
        current_seq_length = inputs.size(1)
        # Truncate inputs if it exceeds context_length
        if current_seq_length > self.config.max_len:
            inputs = inputs[:, -self.config.max_len:]
        # we only pass targets on training to calculate loss

        mask = torch.ones_like(inputs.float().mean(dim = -1))
        # print(inputs.shape, mask.shape)
        logits, _ = self(inputs, mask)  

        # for all the batches, get the embeds for last predicted sequence
        logits = logits[:, -1, :] 
        probs = F.softmax(logits, dim=1)  
        probs = probs.squeeze(0).permute(1,0)
        # print(probs.shape)

        # get the probable token based on the input probs
        idx_next = torch.multinomial(probs, num_samples=1) 
        idx_next = idx_next.permute(1,0).unsqueeze(0)
        # print(idx_next.shape)
        inputs = torch.cat([inputs, idx_next], dim=1)
        output = torch.cat([output, idx_next], dim=1)
        # output shape: [batch_size, seq_len, multitrack_num]

    return [tokenizer(out.tolist()) for out in output]
```

NHeLv1 commented 1 week ago

I think I've understood something. When using the Octuple tokenizer, I didn't retrieve the ids of each of its sub-vocabularies, and the embedding size my model uses exceeded the id range, which is why a KeyError occurred. So, where can I obtain Octuple's ids?
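
A hedged sketch of how the per-type vocabulary sizes could be inspected; it assumes that, for multi-vocabulary tokenizers like Octuple, `tokenizer.vocab` is a list of `{token string: id}` dicts (one per token type) and that `is_multi_voc` is the attribute exposing this, which is how I understand recent MidiTok versions:

```python
from miditok import Octuple

tokenizer = Octuple()  # default config, for illustration only

if tokenizer.is_multi_voc:  # assumed attribute name
    # One dict per token type (Pitch, Velocity, Duration, ...)
    vocab_sizes = [len(voc) for voc in tokenizer.vocab]
    print(vocab_sizes)  # one embedding/output size per token type

    # Ids of the special tokens in the first sub-vocabulary, if configured
    specials = ("PAD_None", "BOS_None", "EOS_None")
    print({tok: tokenizer.vocab[0][tok] for tok in specials if tok in tokenizer.vocab[0]})
```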

NHeLv1 commented 1 week ago

I changed my Transformer to use multiple vocabularies, but I got a new problem *(screenshot of the new error)*. Here is my code:

```python
class NaiveGPT(nn.Module):
    def __init__(self, config: GPT1Config, multi_vocab_lens=None):
        ...

    def forward(self, inputs, mask, targets=None):
        bs, sl, track_num = inputs.shape
        inputs = inputs.permute(2, 0, 1)
        # one embedding layer per token type, concatenated along the feature dim
        embedded = [embed(token) for embed, token in zip(self.wte, inputs)]
        concat_embed = torch.cat(embedded, dim=-1)

        # logits = self.wte(inputs)  # dim -> batch_size, sequence_length, d_model
        logits = self.projection_in(concat_embed)
        logits = self.wpe(logits)
        for block in self.blocks:
            logits = block(logits, mask)
        # logits = self.linear1(logits)
        logits = self.projection_out(logits)

        logits = logits.view(bs, sl, -1, track_num).permute(3, 0, 1, 2)

        # one output head per token type
        out_tokens = [linear(logit) for linear, logit in zip(self.linear2vocab, logits)]

        # out_tokens = torch.cat(out_tokens, dim=-1)

        loss_ce = None
        if targets is not None:
            targets = targets.permute(2, 0, 1)

            # sum the cross-entropy losses over all token types
            for i, (tgt_track, out_track) in enumerate(zip(targets, out_tokens)):
                batch_size, sequence_length, track_vocab_size = out_track.shape
                out_track = out_track.view(batch_size * sequence_length, track_vocab_size)
                tgt_track = tgt_track.view(batch_size * sequence_length)

                loss_i = F.cross_entropy(out_track, tgt_track, ignore_index=-100)
                if i == 0:
                    loss_ce = loss_i
                else:
                    loss_ce += loss_i
            # loss = F.cross_entropy(out_tokens, targets, ignore_index=-100)
        return out_tokens, loss_ce

    def generate(self, inputs, mask, max_new_tokens, tokenizer):
        output = inputs.clone()
        for _ in range(max_new_tokens):
            current_seq_length = inputs.size(1)
            # Truncate inputs if it exceeds context_length
            if current_seq_length > self.config.max_len:
                inputs = inputs[:, -self.config.max_len:]

            mask = torch.ones_like(inputs.float().mean(dim=-1))
            logits, _ = self(inputs, mask)

            # multi-track: sample one id per token type, independently
            idx_next = []
            for i, track in enumerate(logits):
                track_logits = track[:, -1, :]
                probs = F.softmax(track_logits, dim=1)
                track_idx_next = torch.multinomial(probs, num_samples=1)
                idx_next.append(track_idx_next)
            idx_next = torch.cat(idx_next, dim=-1)
            idx_next = idx_next.unsqueeze(1)

            inputs = torch.cat([inputs, idx_next], dim=1)
            output = torch.cat([output, idx_next], dim=1)

        for out in output:
            print(out.tolist())
        return [tokenizer(out.tolist()) for out in output]
```

Natooz commented 6 days ago

(Apologies for the late reply)

Indeed, Octuple is "multi-vocabulary", i.e. the model needs multiple embedding layers and multiple output modules (one per token type). The last error seems to occur because the model predicted a non-valid time signature token (probably a "something_None" token within this sub-vocabulary). I don't know how the model is trained exactly, so I can't tell whether it is expected to predict tokens with a "None" value part (EOS/BOS?). If it is EOS, I suggest preprocessing the token ids and cutting the sequence when one is met.
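
A minimal sketch of that preprocessing step, under the assumptions that each Octuple sub-vocabulary contains "EOS_None"/"BOS_None" entries, that `tokenizer.vocab` is a list of dicts (one per token type), and that the generated ids are a list of per-step id lists; `cut_at_eos` is a hypothetical helper, not part of MidiTok:

```python
def cut_at_eos(token_ids, tokenizer):
    """Truncate a generated multi-vocab sequence at the first EOS/BOS step.

    token_ids: list of lists, one inner list of ids (one per token type) per step.
    Assumes "EOS_None"/"BOS_None" exist in each sub-vocabulary.
    """
    special_ids = [
        {voc[tok] for tok in ("EOS_None", "BOS_None") if tok in voc}
        for voc in tokenizer.vocab  # one dict per token type for multi-vocab tokenizers
    ]
    cleaned = []
    for step in token_ids:
        # Stop as soon as any token type predicts a special token
        if any(tok_id in special_ids[i] for i, tok_id in enumerate(step)):
            break
        cleaned.append(step)
    return cleaned

# Usage (assumed): decode only the truncated part
# score = tokenizer(cut_at_eos(generated_ids, tokenizer))
```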

NHeLv1 commented 6 days ago

Thanks, it is indeed caused by EOS/BOS as you said. I ignored them when generating, but the generated music still ends up poor, so I think Octuple may not have been designed for generation tasks.

Natooz commented 6 days ago

I indeed do not recommend Octuple for generation tasks, for several reasons:

1. Having multiple vocabularies/input modules/output modules increases the complexity of the model and of the training procedure, which now must compute several losses; this has to be implemented and tested, and might not be compatible with popular libraries such as HF transformers.
2. Sampling multiple tokens from distributions that are computed without any of them conditioning the others leads to higher variance and potentially incoherent results. That's less of an issue for bidirectional transformers (i.e. no attention masking), as the prediction error doesn't accumulate over autoregressive generation steps.
3. BPE/Unigram is, to me, a much better alternative to reduce the sequence length: it almost matches Octuple's reduction ratio (and can surpass it depending on the data/vocab size), it is compatible out of the box with all models/libraries, and it benefits from the increased "semantic"/information density of the aggregated tokens in the vocabulary.
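
For completeness, a rough sketch of the BPE route with REMI; the training API has changed across MidiTok versions (`learn_bpe` in older releases, a `train` method in recent ones), so treat the exact call, the vocabulary size, and the corpus path below as assumptions and check the docs for your version:

```python
from pathlib import Path

from miditok import REMI, TokenizerConfig

midi_paths = list(Path("path/to/midis").glob("**/*.mid"))  # placeholder path

tokenizer = REMI(TokenizerConfig(use_programs=True))  # illustrative config

# Learn a BPE (or Unigram) vocabulary on the corpus to shrink sequence lengths.
# Recent MidiTok versions expose this as `train`; older ones as `learn_bpe`.
tokenizer.train(vocab_size=20000, files_paths=midi_paths)

tokenizer.save_params(Path("tokenizer_bpe.json"))
```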