Natooz / MidiTok

MIDI / symbolic music tokenizers for Deep Learning models 🎶
https://miditok.readthedocs.io/
MIT License

Octuple decode KeyError #206

Open NHeLv1 opened 6 hours ago

NHeLv1 commented 6 hours ago

I trained an autoregressive Transformer using Octuple, but a KeyError occurs when the output is eventually decoded into MIDI.
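For context, Octuple is one of MidiTok's multi-vocabulary tokenizers: each time step is a tuple of ids, one per token type (Pitch, Velocity, Duration, Position, Bar, ...), and each id is looked up in its own sub-vocabulary at decode time. A minimal sketch of how such a KeyError can arise (the deliberately out-of-range id is hypothetical; the `tokenizer(...)` call mirrors the one used later in this thread):

```python
# Minimal sketch, assuming a default-configured Octuple tokenizer.
from miditok import Octuple

tokenizer = Octuple()

# Multi-vocab tokenizers keep one vocabulary dict per token type
vocab_sizes = [len(v) for v in tokenizer.vocab]

# One time step = one id per stream. An id that only exists in the
# largest stream is out of range for the smaller ones, and decoding it
# fails with a KeyError in the id -> token lookup.
bad_step = [max(vocab_sizes) - 1] * len(vocab_sizes)  # deliberately invalid
tokenizer([bad_step])  # raises KeyError during decoding
```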

NHeLv1 commented 6 hours ago

[image: screenshot of the error traceback]

Natooz commented 5 hours ago

Hi, can you share more information on the tokenizer (ideally its configuration, as a tokenizer.json file for example) and the token ids for which this error is occurring?
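For reference, a MidiTok tokenizer can serialize its parameters to such a JSON file, and the same file can rebuild an identical tokenizer later. A short sketch (the file name is arbitrary):

```python
# Save the tokenizer's parameters/config so they can be attached to the issue
tokenizer.save_params("tokenizer.json")

# Rebuild an identical tokenizer from the saved file
from miditok import Octuple
tokenizer = Octuple(params="tokenizer.json")
```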

NHeLv1 commented 4 hours ago

[image: screenshot of the tokenizer configuration]

I haven't used a JSON file, just a config, and I don't get this error with REMI. The error occurs when I call the tokenizer as below:

```python
# (requires: import torch; import torch.nn.functional as F)
def generate(self, inputs, mask, max_new_tokens, tokenizer):
    output = inputs.clone()
    for _ in range(max_new_tokens):
        current_seq_length = inputs.size(1)
        # Truncate inputs if the sequence exceeds the context length
        if current_seq_length > self.config.max_len:
            inputs = inputs[:, -self.config.max_len:]

        # Targets are only passed during training to compute the loss
        mask = torch.ones_like(inputs.float().mean(dim=-1))
        logits, _ = self(inputs, mask)

        # For every batch element, keep only the logits of the last position
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=1)
        probs = probs.squeeze(0).permute(1, 0)

        # Sample one id per Octuple token type from the probabilities
        idx_next = torch.multinomial(probs, num_samples=1)
        idx_next = idx_next.permute(1, 0).unsqueeze(0)
        inputs = torch.cat([inputs, idx_next], dim=1)
        output = torch.cat([output, idx_next], dim=1)
        # output shape: [batch_size, seq_len, num_token_types]

    return [tokenizer(out.tolist()) for out in output]
```
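A likely source of the KeyError in a loop like this: Octuple keeps a separate vocabulary per token type, so if the model's output dimension is sized to the largest of them, `torch.multinomial` can sample an id that does not exist in the smaller sub-vocabularies, and the lookup then fails at decode time. A sketch of per-stream sampling that only draws valid ids (the `sample_last_step` helper is hypothetical, and the `[num_streams, vocab_size]` logits layout and `vocab_sizes` list are assumptions about this model, not part of MidiTok's API):

```python
import torch
import torch.nn.functional as F

def sample_last_step(logits: torch.Tensor, vocab_sizes: list[int]) -> torch.Tensor:
    """Sample one token id per Octuple stream, restricted to that stream's vocab.

    logits: [num_streams, padded_vocab_size] for the last position (assumed layout).
    vocab_sizes: e.g. [len(v) for v in tokenizer.vocab] for a multi-vocab tokenizer.
    """
    ids = []
    for stream, size in enumerate(vocab_sizes):
        # Keep only the ids that exist in this stream's sub-vocabulary
        probs = F.softmax(logits[stream, :size], dim=-1)
        ids.append(torch.multinomial(probs, num_samples=1))
    return torch.cat(ids)  # [num_streams], every id valid for its own stream
```

With something along these lines in place of the raw multinomial call, every id that reaches `tokenizer(out.tolist())` exists in its stream's vocabulary, so the decode-time lookup cannot miss on an out-of-range id.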