salu133445 / muspy

A toolkit for symbolic music generation
https://salu133445.github.io/muspy/
MIT License
435 stars, 51 forks

Handle training-validation-test splits in NESMusicDatabase #29

salu133445 opened this issue 3 years ago (status: open)

salu133445 commented 3 years ago

The current implementation of NESMusicDatabase does not handle the training-validation-test splits provided in the original dataset. To avoid changing the base Dataset class too much, we could add a subset method and achieve something like the following.

import muspy

nes = muspy.NESMusicDatabase("data/nes/")

# Proposed API: subset() does not exist yet.
training_set = nes.subset("training")  # also a Dataset object
validation_set = nes.subset("validation")
test_set = nes.subset("test")

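Such a subset method mostly needs to pick out the filenames that belong to one split. Below is a minimal sketch of that filtering step. It assumes the raw files sit in one top-level directory per split. The directory names `train`/`valid`/`test`, the `SPLIT_DIRS` mapping, `select_split` and the example file names are all hypothetical. They do not come from muspy or from a confirmed NES-MDB layout.

```python
from pathlib import PurePosixPath
from typing import Dict, List, Sequence

# Hypothetical mapping from the subset names proposed in this issue to the
# directory names the raw files are assumed to live under.
SPLIT_DIRS: Dict[str, str] = {
    "training": "train",
    "validation": "valid",
    "test": "test",
}


def select_split(filenames: Sequence[str], subset: str) -> List[str]:
    """Return the filenames whose top-level directory matches the split."""
    if subset not in SPLIT_DIRS:
        raise ValueError(f"Unknown subset: {subset!r}")
    wanted = SPLIT_DIRS[subset]
    # parts[0] is the first path component, e.g. "train" in "train/001_a.mid"
    return [f for f in filenames if PurePosixPath(f).parts[0] == wanted]


files = [
    "train/001_a.mid",
    "train/002_b.mid",
    "valid/003_c.mid",
    "test/004_d.mid",
]
print(select_split(files, "training"))  # ['train/001_a.mid', 'train/002_b.mid']
print(select_split(files, "test"))      # ['test/004_d.mid']
```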
cifkao commented 3 years ago

I also need to load a dataset that has splits. I added a part parameter to my constructor:

from pathlib import Path
from typing import Optional, Union

import muspy


class Groove2GrooveDataset(muspy.RemoteFolderDataset):

    ...

    def __init__(
        self,
        root: Union[str, Path],
        download_and_extract: bool = False,
        cleanup: bool = False,
        convert: bool = False,
        kind: str = "json",
        n_jobs: int = 1,
        ignore_exceptions: bool = True,
        use_converted: Optional[bool] = None,
        part: str = "train",
    ):
        # Store the part before calling the parent constructor, which may
        # already look at converted_dir.
        self.part = part
        super().__init__(
            root=root, download_and_extract=download_and_extract,
            cleanup=cleanup, convert=convert, kind=kind, n_jobs=n_jobs,
            ignore_exceptions=ignore_exceptions, use_converted=use_converted)

        # Restrict the raw files to the requested part of the split.
        path = self.root / "groove2groove-data-v1.0.0" / "midi" / part / "fixed"
        self.raw_filenames = sorted(path.rglob("*." + self._extension))
        self._filenames = self.raw_filenames

However, this doesn't work:

>>> test_data = Groove2GrooveDataset('/tmp/groove2groove-data', part='test', download_and_extract=True)  # OK
>>> test_data.convert()  # OK
>>> val_data = Groove2GrooveDataset('/tmp/groove2groove-data', part='val')  # OK, reuses downloaded data
>>> val_data.convert()  # not OK: conversion is skipped because '_converted' already holds the converted test data

Edit: overriding converted_dir so that each part gets its own converted directory fixed it:

    @property
    def converted_dir(self):
        return self.root / "_converted_{}".format(self.part)
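Here is a toy model of why the override helps. It assumes what the session above describes: conversion is skipped when the converted directory already looks populated. `convert_part`, `run_demo` and the `.success` marker file are hypothetical. This is not muspy's actual convert logic.

```python
import tempfile
from pathlib import Path
from typing import List


def convert_part(part: str, converted_dir: Path) -> bool:
    """Toy stand-in for a dataset's convert step (hypothetical, not muspy code).

    Skips the work when converted_dir already contains a success marker and
    returns whether a conversion actually ran.
    """
    marker = converted_dir / ".success"
    if marker.exists():
        return False  # looks "already converted", whichever part wrote it
    converted_dir.mkdir(parents=True, exist_ok=True)
    (converted_dir / f"{part}.json").write_text("{}")  # placeholder output
    marker.touch()
    return True


def run_demo(per_part_dirs: bool) -> List[bool]:
    """Convert the 'test' part, then the 'val' part, under one shared root."""
    results = []
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for part in ("test", "val"):
            name = f"_converted_{part}" if per_part_dirs else "_converted"
            results.append(convert_part(part, root / name))
    return results


print(run_demo(per_part_dirs=False))  # [True, False]: 'val' is skipped
print(run_demo(per_part_dirs=True))   # [True, True]: each part converts
```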
salu133445 commented 3 years ago

I see your point. We could have a Subset class for this, which does not have a convert method but can be iterated over just like a regular dataset. The key is that a subset always shares the data with its parent dataset; the only difference is which filenames it looks for.

import muspy

nes = muspy.NESMusicDatabase("data/nes/")
nes.convert()  # convert once, on the parent dataset
training_set = nes.subset("training")  # proposed API

Whereas this would not work:

nes = muspy.NESMusicDatabase("data/nes/")
training_set = nes.subset("training")
training_set.convert()  # error: a subset has no convert method
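Here is a minimal sketch of the proposed Subset design. The subset holds only the filenames it should look for and reads the converted items from its parent. It has no convert method. `ToyDataset`, its dict-based storage and the `music(...)` placeholder items are hypothetical stand-ins, not muspy's Dataset classes.

```python
from typing import Dict, Iterator, List


class ToyDataset:
    """Hypothetical stand-in for a dataset with splits (not muspy's class)."""

    def __init__(self, splits: Dict[str, List[str]]):
        self.splits = splits                # subset name -> filenames
        self.data: Dict[str, str] = {}      # filename -> converted item

    def convert(self) -> None:
        # Toy "conversion": every file in every split is converted once.
        for filenames in self.splits.values():
            for filename in filenames:
                self.data[filename] = f"music({filename})"

    def subset(self, name: str) -> "Subset":
        return Subset(self, self.splits[name])


class Subset:
    """A view that shares the parent's converted data and stores only filenames.

    There is deliberately no convert() method, so calling it raises
    AttributeError. Items are looked up in the parent on access, so the
    parent must already be converted.
    """

    def __init__(self, parent: ToyDataset, filenames: List[str]):
        self.parent = parent
        self.filenames = list(filenames)

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, index: int) -> str:
        # A KeyError here means the parent dataset has not been converted yet.
        return self.parent.data[self.filenames[index]]

    def __iter__(self) -> Iterator[str]:
        for i in range(len(self)):
            yield self[i]


nes = ToyDataset({"training": ["a.mid", "b.mid"], "test": ["c.mid"]})
nes.convert()
training_set = nes.subset("training")
print(list(training_set))  # ['music(a.mid)', 'music(b.mid)']

try:
    training_set.convert()
except AttributeError:
    print("subsets cannot be converted")
```

Because the subset stores only filenames, the parent keeps a single converted copy of the data. Every split then reads from that same copy, which avoids the per-part conversion clash described earlier in the thread.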