JamesOwers / midi_degradation_toolkit

A toolkit for generating datasets of midi files which have been degraded to be 'un-musical'.
MIT License

Datasets and loaders #53

Closed: JamesOwers closed this 5 years ago

JamesOwers commented 5 years ago

First draft here. Rerun make_dataset.py and you'll get the new files {train,valid,test}_cmd_corpus.csv. These files are read by the CommandDataset class in mdtk.pytorch_datasets, which returns the data item at a given index. For example:

```python
from mdtk.pytorch_datasets import CommandDataset, CommandVocab
vocab = CommandVocab()
train_dataset = CommandDataset('./acme/train_cmd_corpus.csv', vocab, 128)
... Loading Dataset: 18005it [00:00, 144793.12it/s]
```
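To make the mechanics concrete, here is a minimal sketch of what a dataset like this might do internally. It is not mdtk's CommandDataset: the CSV column names, the space-separated token format and padding index 0 are assumptions (index 1 for unknown tokens matches the decoded output further down). Truncating and padding every sequence to seq_len (presumably the 128 above) gives all items the same length, and defining `__len__` and `__getitem__` is all a map-style PyTorch dataset needs.

```python
import csv
import io


class CommandCorpusSketch:
    """Hypothetical map-style dataset over a command-corpus CSV.

    Assumed layout (not necessarily mdtk's): columns deg_commands,
    clean_commands and deg_label, with commands stored as space-separated tokens.
    """

    def __init__(self, csv_file, stoi, seq_len, unk_idx=1, pad_idx=0):
        self.rows = list(csv.DictReader(csv_file))
        self.stoi = stoi
        self.seq_len = seq_len
        self.unk_idx = unk_idx
        self.pad_idx = pad_idx  # assumption: index 0 is reserved for padding

    def _encode(self, command_str):
        # Unknown tokens fall back to unk_idx
        ids = [self.stoi.get(tok, self.unk_idx) for tok in command_str.split()]
        ids = ids[:self.seq_len]  # truncate long sequences
        return ids + [self.pad_idx] * (self.seq_len - len(ids))  # pad short ones

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        return {
            'deg_commands': self._encode(row['deg_commands']),
            'clean_commands': self._encode(row['clean_commands']),
            'deg_label': int(row['deg_label']),
        }


# Usage with an in-memory CSV (for a real file: open(path, newline=''))
demo_csv = io.StringIO(
    'deg_commands,clean_commands,deg_label\n'
    'o45 o67 f45,o45 f45,2\n'
)
demo = CommandCorpusSketch(demo_csv, stoi={'o45': 49, 'f45': 177}, seq_len=4)
```

Here `demo[0]` encodes the unseen 'o67' as 1 and pads both sequences out to length 4.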

Each data item is a dictionary containing the degraded commands as a list of integers, likewise the clean commands, and the degradation label. (I'm assuming that deg_label 0 will always mean 'not degraded', so we don't need an is_degraded flag.)

```python
train_dataset[0]
... {'deg_commands': [3, 49, 50, 1, 177, 1, 47, 64, 69, 1, 71, 1, 175, 1, 178, 45, 1, 173, 192, 1, 197, 1, 40, 59, 62, 64, 69, 73, 1, 187, 1, 192, 197, 201, 1, 64, 1, 71, 1, 168, 1, 47, 1, 192, 1, 199, 1, 175, 64, 69, 1, 40, 1, 192, 1, 168, 197, 1, 50, 64, 68, 68, 71, 76, 1, 178, 1, 199, 49, 1, 199, 1, 190, 196, 204, 1, 192, 1, 177, 196, 62, 1, 47, 64, 66, 69, 69, 74, 1, 175, 1, 194, 45, 1, 190, 197, 202, 1, 173, 1, 192, 197, 1, 40, 59, 64, 68, 68, 71, 76, 1, 187, 1, 168, 64, 1, 47, 1, 192, 1, 175, 1, 40, 64, 1, 168, 1, 192],
'clean_commands': [3, 49, 50, 1, 177, 1, 47, 64, 69, 1, 175, 1, 178, 45, 1, 173, 192, 1, 197, 1, 40, 59, 62, 64, 69, 73, 1, 187, 1, 192, 197, 201, 1, 64, 1, 71, 1, 168, 1, 47, 1, 192, 1, 199, 1, 175, 64, 69, 1, 40, 1, 192, 1, 168, 197, 1, 50, 64, 68, 68, 71, 76, 1, 178, 1, 199, 49, 1, 190, 196, 204, 1, 192, 1, 177, 196, 62, 1, 47, 64, 66, 69, 69, 74, 1, 175, 1, 194, 45, 1, 190, 197, 202, 1, 173, 1, 192, 197, 1, 40, 59, 64, 68, 68, 71, 71, 76, 1, 187, 1, 168, 64, 1, 47, 1, 192, 1, 175, 1, 40, 64, 1, 168, 1, 192, 50, 68, 1],
'deg_label': 2}
```
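If deg_label 0 does mean 'not degraded', a binary target can be computed on the fly rather than stored. It can also help to see exactly where the two command lists differ, and Python's difflib works directly on lists of integers. A sketch with hypothetical helper names, using the first few tokens of the item above, where the degraded version has an extra 71, 1 before the 175:

```python
from difflib import SequenceMatcher


def is_degraded(deg_label):
    # Relies on the assumed convention: deg_label 0 means 'not degraded'
    return deg_label != 0


def changed_spans(clean, deg):
    """Return the non-equal opcodes that turn `clean` into `deg`."""
    matcher = SequenceMatcher(a=clean, b=deg, autojunk=False)
    return [op for op in matcher.get_opcodes() if op[0] != 'equal']


# First few tokens of clean_commands / deg_commands from the item above
clean = [3, 49, 50, 1, 177, 1, 47, 64, 69, 1, 175]
deg = [3, 49, 50, 1, 177, 1, 47, 64, 69, 1, 71, 1, 175]
spans = changed_spans(clean, deg)
```

`spans` holds a single `('insert', 10, 10, 10, 12)`: `deg[10:12]` was inserted at position 10 of the clean sequence. `autojunk=False` stops difflib from treating very frequent tokens (such as 1) as junk on long sequences.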

The vocabulary can be used to get the tokens back:

```python
[vocab.itos[idx] for idx in train_dataset[0]['deg_commands']]
... ['<sos>', 'o45', 'o46', '<unk>', 'f45', '<unk>', 'o43', 'o60', 'o65', '<unk>', 'o67', '<unk>', 'f43', '<unk>', 'f46', 'o41', '<unk>', 'f41', 'f60', '<unk>', 'f65', '<unk>', 'o36', 'o55', 'o58', 'o60', 'o65', 'o69', '<unk>', 'f55', '<unk>', 'f60', 'f65', 'f69', '<unk>', 'o60', '<unk>', 'o67', '<unk>', 'f36', '<unk>', 'o43', '<unk>', 'f60', '<unk>', 'f67', '<unk>', 'f43', 'o60', 'o65', '<unk>', 'o36', '<unk>', 'f60', '<unk>', 'f36', 'f65', '<unk>', 'o46', 'o60', 'o64', 'o64', 'o67', 'o72', '<unk>', 'f46', '<unk>', 'f67', 'o45', '<unk>', 'f67', '<unk>', 'f58', 'f64', 'f72', '<unk>', 'f60', '<unk>', 'f45', 'f64', 'o58', '<unk>', 'o43', 'o60', 'o62', 'o65', 'o65', 'o70', '<unk>', 'f43', '<unk>', 'f62', 'o41', '<unk>', 'f58', 'f65', 'f70', '<unk>', 'f41', '<unk>', 'f60', 'f65', '<unk>', 'o36', 'o55', 'o60', 'o64', 'o64', 'o67', 'o72', '<unk>', 'f55', '<unk>', 'f36', 'o60', '<unk>', 'o43', '<unk>', 'f60', '<unk>', 'f43', '<unk>', 'o36', 'o60', '<unk>', 'f36', '<unk>', 'f60']
```

As you can see, every 1 decodes to <unk>, which is rather more <unk> than I would like...
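The decoded output suggests a simple index layout: 1 is <unk>, 3 is <sos>, onsets 'o<pitch>' sit at 4 + pitch and offsets 'f<pitch>' at 132 + pitch (49 is 'o45', 177 is 'f45'). The <unk> positions fall between groups of onsets and offsets, which suggests, without proving it, that they are time shifts missing from the vocabulary. The sketch below rebuilds that layout (indices 0 and 2 are guesses, and 't10'/'t20' are hypothetical time-shift names) to show how anything absent from the vocabulary silently collapses to 1, and counts missing tokens on the raw, pre-encoding sequences, which is one way to find out which commands are falling through.

```python
from collections import Counter

# Hypothetical layout inferred from the decoded output; '<pad>' and '<eos>' are guesses
SPECIALS = ['<pad>', '<unk>', '<eos>', '<sos>']


def build_vocab_sketch(n_pitches=128):
    itos = list(SPECIALS)
    itos += [f'o{p}' for p in range(n_pitches)]  # note onsets: indices 4..131
    itos += [f'f{p}' for p in range(n_pitches)]  # note offsets: indices 132..259
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return itos, stoi


def encode(tokens, stoi, unk='<unk>'):
    """Map tokens to indices; anything missing from stoi silently becomes <unk>."""
    unk_idx = stoi[unk]
    return [stoi.get(tok, unk_idx) for tok in tokens]


def unknown_token_report(token_seqs, stoi):
    """Return (fraction of raw tokens missing from stoi, missing tokens by frequency)."""
    total = 0
    missing = Counter()
    for seq in token_seqs:
        total += len(seq)
        missing.update(tok for tok in seq if tok not in stoi)
    rate = sum(missing.values()) / total if total else 0.0
    return rate, missing.most_common()


itos, stoi = build_vocab_sketch()
# 't10' and 't20' stand in for time-shift tokens (hypothetical names) absent from this vocab
raw = [['<sos>', 'o45', 't10', 'f45'], ['t10', 't20']]
encoded = [encode(seq, stoi) for seq in raw]
rate, missing = unknown_token_report(raw, stoi)
```

Here `encoded` is `[[3, 49, 1, 177], [1, 1]]`, and the report flags half of the raw tokens as missing, led by 't10'.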

We can do most of the tasks with this data, and if required you can pass transform_to_torchtensor as the transform to get torch versions of the data:

```python
from mdtk.pytorch_datasets import transform_to_torchtensor
train_dataset = CommandDataset('./acme/train_cmd_corpus.csv', vocab, 128, transform=transform_to_torchtensor)
train_dataset[0]
```

```
{'deg_commands': tensor([  3,  49,  50,   1, 177,   1,  47,  64,  69,   1,  71,   1, 175,   1,
        178,  45,   1, 173, 192,   1, 197,   1,  40,  59,  62,  64,  69,  73,
          1, 187,   1, 192, 197, 201,   1,  64,   1,  71,   1, 168,   1,  47,
          1, 192,   1, 199,   1, 175,  64,  69,   1,  40,   1, 192,   1, 168,
        197,   1,  50,  64,  68,  68,  71,  76,   1, 178,   1, 199,  49,   1,
        199,   1, 190, 196, 204,   1, 192,   1, 177, 196,  62,   1,  47,  64,
         66,  69,  69,  74,   1, 175,   1, 194,  45,   1, 190, 197, 202,   1,
        173,   1, 192, 197,   1,  40,  59,  64,  68,  68,  71,  76,   1, 187,
          1, 168,  64,   1,  47,   1, 192,   1, 175,   1,  40,  64,   1, 168,
          1, 192]), 'clean_commands': tensor([  3,  49,  50,   1, 177,   1,  47,  64,  69,   1, 175,   1, 178,  45,
          1, 173, 192,   1, 197,   1,  40,  59,  62,  64,  69,  73,   1, 187,
          1, 192, 197, 201,   1,  64,   1,  71,   1, 168,   1,  47,   1, 192,
          1, 199,   1, 175,  64,  69,   1,  40,   1, 192,   1, 168, 197,   1,
         50,  64,  68,  68,  71,  76,   1, 178,   1, 199,  49,   1, 190, 196,
        204,   1, 192,   1, 177, 196,  62,   1,  47,  64,  66,  69,  69,  74,
          1, 175,   1, 194,  45,   1, 190, 197, 202,   1, 173,   1, 192, 197,
          1,  40,  59,  64,  68,  68,  71,  71,  76,   1, 187,   1, 168,  64,
          1,  47,   1, 192,   1, 175,   1,  40,  64,   1, 168,   1, 192,  50,
         68,   1]), 'deg_label': tensor(2)}
```
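The source of transform_to_torchtensor isn't shown here, but judging by the output it wraps each value of the item dict in a tensor: the command lists become 1-D integer tensors and the label a 0-dimensional one. A hypothetical stand-in along those lines (the real function may differ):

```python
import torch


def to_torch_tensors(item):
    """Hypothetical stand-in for transform_to_torchtensor: wrap every value in a tensor."""
    # torch.tensor on a list of Python ints gives an int64 tensor; on a bare int, a 0-d tensor
    return {key: torch.tensor(value) for key, value in item.items()}


example = to_torch_tensors({'deg_commands': [3, 49, 50], 'clean_commands': [3, 49], 'deg_label': 2})
```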

With this transformed data, the dataset can be used directly with the vanilla PyTorch DataLoader for batching:

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4,
                              shuffle=True, num_workers=4)
for batch_nr, batch in enumerate(train_dataloader):
    print(f'batch nr {batch_nr}')
    print(f'The batch is a {type(batch)}')
    print('But its values are pytorch tensors')
    labels = batch['deg_label']
    in_tokens = batch['deg_commands']
    print(labels.shape, in_tokens.shape)
    if batch_nr > 3:
        break
```

```
batch nr 0
The batch is a <class 'dict'>
But its values are pytorch tensors
torch.Size([4]) torch.Size([4, 128])
batch nr 1
The batch is a <class 'dict'>
But its values are pytorch tensors
torch.Size([4]) torch.Size([4, 128])
batch nr 2
The batch is a <class 'dict'>
But its values are pytorch tensors
torch.Size([4]) torch.Size([4, 128])
batch nr 3
The batch is a <class 'dict'>
But its values are pytorch tensors
torch.Size([4]) torch.Size([4, 128])
batch nr 4
The batch is a <class 'dict'>
But its values are pytorch tensors
torch.Size([4]) torch.Size([4, 128])
```
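The default collate function stacks each dict value across the batch, which is why the labels come out as `[4]` and the commands as `[4, 128]`; that only works because every sequence has the same length, presumably thanks to the 128 passed to CommandDataset. If sequences were left at their natural lengths, a custom `collate_fn` could pad per batch instead. A sketch on toy data (the items and padding index 0 are assumptions):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy stand-in for the dataset: variable-length command sequences (hypothetical data)
toy_items = [
    {'deg_commands': torch.tensor([3, 49, 1, 177]), 'deg_label': torch.tensor(2)},
    {'deg_commands': torch.tensor([3, 40, 1]), 'deg_label': torch.tensor(0)},
    {'deg_commands': torch.tensor([3, 47]), 'deg_label': torch.tensor(1)},
    {'deg_commands': torch.tensor([3, 64, 1, 192]), 'deg_label': torch.tensor(0)},
]


def pad_collate(batch, pad_idx=0):
    """Pad sequences to the longest in this batch; stack the scalar labels."""
    return {
        'deg_commands': pad_sequence([b['deg_commands'] for b in batch],
                                     batch_first=True, padding_value=pad_idx),
        'deg_label': torch.stack([b['deg_label'] for b in batch]),
    }


# A plain list works as a map-style dataset: it has __getitem__ and __len__
loader = DataLoader(toy_items, batch_size=2, shuffle=False, collate_fn=pad_collate)
batches = list(loader)
```

Padding per batch avoids wasting compute on short pieces, at the cost of shapes that vary from batch to batch.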

Finally, you can use an nn.Embedding layer as the first layer of the net to get from these integers to learned vectors (or to one-hots, by fixing its weights to a frozen identity matrix). For an example, see how token embedding is done in BERT here (it calls out the ... here).
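Both options can be sketched with standard PyTorch calls. The vocabulary size of 260 and `padding_idx=0` below are assumptions; in practice use the real vocabulary size and padding index. A batch of integer tokens shaped like `batch['deg_commands']` becomes either learned dense vectors or fixed one-hot vectors:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
vocab_size = 260  # assumption: substitute the real vocabulary size
tokens = torch.randint(0, vocab_size, (4, 128))  # fake batch shaped like batch['deg_commands']

# Learned dense vectors; padding_idx=0 assumes index 0 is padding (its vector stays zero)
learned = nn.Embedding(vocab_size, 64, padding_idx=0)
dense = learned(tokens)  # shape (4, 128, 64)

# Fixed one-hot vectors: identity weights, frozen so the optimizer never updates them
one_hot = nn.Embedding.from_pretrained(torch.eye(vocab_size), freeze=True)
onehots = one_hot(tokens)  # shape (4, 128, 260)
```

The frozen identity embedding produces the same result as `F.one_hot(tokens, vocab_size).float()`; the learned version is usually preferred because one-hots scale poorly with vocabulary size.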

TODO:

apmcleod commented 5 years ago

lgtm. Made a few changes.

Regarding your TODOs:

- Implement rounding to nearest time_increment in df_to_command_str: DONE
- Handle cases where there is a longer pause than max_time_shift in df_to_command_str (just output multiple time shift commands in a row): DONE
- Overlapping notes mess up the conversion back from commands to df (see issue #20): complex, also see issues #46 and #8.
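The first two items come down to a little arithmetic. A sketch under stated assumptions (times in milliseconds, hypothetical 't<ms>' token names, max_time_shift a multiple of time_increment so results stay on the grid); this is not the actual df_to_command_str code:

```python
import math


def quantize(time_ms, time_increment):
    """Round a non-negative time to the nearest multiple of time_increment.

    Halves round up; Python's built-in round() would round halves to even.
    """
    return math.floor(time_ms / time_increment + 0.5) * time_increment


def time_shift_tokens(gap_ms, time_increment, max_time_shift):
    """Quantize a pause, then split it into shifts of at most max_time_shift each."""
    gap = quantize(gap_ms, time_increment)
    shifts = []
    while gap > 0:
        step = min(gap, max_time_shift)
        shifts.append(f't{step}')  # hypothetical time-shift token name
        gap -= step
    return shifts
```

For example, a 2493 ms pause with a 10 ms increment and a 1000 ms maximum shift becomes `['t1000', 't1000', 't490']`, and a pause that rounds to zero produces no tokens at all.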

I'll let you check over this and complete the pull whenever you're ready. In the meantime I'll branch off of this point and continue working.

JamesOwers commented 5 years ago

Phew, that was a big one. I'm loving this process man, thanks.

I think the most important thing to sort out is the overlapping pitches problem. Building the dataset-creation pipeline and an example model should tell us what to do with the --command flag. I think I agree with you now, really: I've no idea how people will want to use this data!
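One way to see why overlapping pitches break the round trip: if two notes of the same pitch overlap, the command stream contains something like o60 ... o60 ... f60 ... f60, and nothing says which offset closes which onset. A small hypothetical helper (notes as plain (onset, pitch, dur) tuples rather than the toolkit's DataFrame) that finds such pairs, e.g. to measure how often the problem occurs in a corpus:

```python
from collections import defaultdict


def find_same_pitch_overlaps(notes):
    """Return (earlier, later) pairs where a note starts before a same-pitch note ends.

    notes: iterable of (onset, pitch, dur) tuples (hypothetical representation).
    """
    by_pitch = defaultdict(list)
    for note in notes:
        by_pitch[note[1]].append(note)

    overlaps = []
    for group in by_pitch.values():
        group.sort()  # by onset within this pitch
        latest = None  # note with the latest end time seen so far
        for note in group:
            onset, _, dur = note
            if latest is not None and onset < latest[0] + latest[2]:
                overlaps.append((latest, note))
            if latest is None or onset + dur > latest[0] + latest[2]:
                latest = note
    return overlaps


notes = [(0, 60, 500), (250, 60, 500), (1000, 60, 100), (0, 64, 100)]
overlaps = find_same_pitch_overlaps(notes)
```

Only the two C4 (pitch 60) notes starting at 0 and 250 ms are reported; tracking the latest end time, rather than just the previous note, also catches a short note hidden inside a long one.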

apmcleod commented 5 years ago

At the moment, on the models branch, I've added a --formats flag which takes any of [none, pianoroll, command] (like --datasets) and creates the PyTorch Dataset CSVs for whichever formats you select. By default, it creates all of them.
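For reference, a list-valued flag like this maps naturally onto argparse's `nargs` and `choices`. The sketch below is hypothetical, not the code on the models branch: the `'+'` arity, reading "all" as every format except none, and treating 'none' as "create no Dataset CSVs" are all assumptions.

```python
import argparse

ALL_FORMATS = ['pianoroll', 'command']  # assumption: "all" means every real format


def build_parser():
    parser = argparse.ArgumentParser(description='Hypothetical sketch of a --formats option')
    parser.add_argument('--formats', nargs='+', choices=['none'] + ALL_FORMATS,
                        default=ALL_FORMATS,
                        help='Formats to create PyTorch Dataset CSVs for (default: all).')
    return parser


def resolve_formats(formats):
    # Assumption: 'none' switches CSV creation off; otherwise drop duplicates, keep order
    return [] if 'none' in formats else list(dict.fromkeys(formats))


parser = build_parser()
default_formats = resolve_formats(parser.parse_args([]).formats)
command_only = resolve_formats(parser.parse_args(['--formats', 'command']).formats)
```

argparse rejects values outside `choices` with a usage error, so a typo like `--formats pianorol` fails early instead of silently creating nothing.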