archinetai / audio-data-pytorch

A collection of useful audio datasets and transforms for PyTorch.
MIT License
129 stars, 22 forks

feat: improve WAV read speed & add artist/genre info from ID3 tags #3

Status: Closed (zaptrem closed this 1 year ago)

zaptrem commented 1 year ago

Description

This pull request introduces several improvements to the WAVDataset class:

First, WAV reading has been optimized so that only the frames that are actually needed are read, rather than the entire file. This can significantly speed up loading for large datasets, especially when files are much longer than the segment being used.
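As a rough illustration of reading only the relevant frames, the sketch below uses Python's standard-library `wave` module to seek to a start frame and read just a window instead of the whole file. The helper name `read_wav_window` is hypothetical and this is not the PR's actual implementation; torchaudio exposes the same idea through the `frame_offset` and `num_frames` arguments of `torchaudio.load`.

```python
import wave
from typing import Tuple


def read_wav_window(path: str, start_frame: int, num_frames: int) -> Tuple[bytes, int, int]:
    """Read at most `num_frames` frames starting at `start_frame` from a PCM WAV file.

    Returns the raw PCM bytes, the channel count and the sample width in bytes.
    Hypothetical helper for illustration only.
    """
    with wave.open(path, "rb") as wav:
        total_frames = wav.getnframes()
        if not 0 <= start_frame <= total_frames:
            raise ValueError(f"start_frame {start_frame} outside [0, {total_frames}]")
        # Jump straight to the first frame we need instead of reading from the start.
        wav.setpos(start_frame)
        # readframes() stops at end of file, so the window may come back shorter.
        data = wav.readframes(num_frames)
        return data, wav.getnchannels(), wav.getsampwidth()
```

Only `num_frames * channels * sample_width` bytes are pulled from disk per call, which is where the speed-up over loading the full file and cropping afterwards comes from.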

Second, the TensorBackedImmutableStringArray class has been introduced to reduce memory usage and speed up the startup of worker processes. It stores a list of strings inside tensors, so reading from it does not cause the data to be copied into each worker (which reading a regular Python list of strings can do, through reference-count updates). Source: https://gist.github.com/vadimkantorov/86c3a46bf25bed3ad45d043ae86fff57
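The linked gist packs strings into tensors; the sketch below shows the same packing idea with standard-library buffers only, so it runs without torch. All strings are UTF-8 encoded into one `bytes` object plus an `array` of offsets, so the collection is a handful of Python objects rather than one object per string, and indexing decodes a slice on demand. The class name `PackedStringArray` is hypothetical.

```python
from array import array
from typing import Iterable


class PackedStringArray:
    """Immutable list of strings packed into one bytes buffer plus an offsets array.

    Hypothetical illustration of the idea behind TensorBackedImmutableStringArray,
    using stdlib buffers instead of torch tensors.
    """

    def __init__(self, strings: Iterable[str]):
        encoded = [s.encode("utf-8") for s in strings]
        # String i occupies self._data[offsets[i]:offsets[i + 1]].
        self._offsets = array("q", [0])
        for chunk in encoded:
            self._offsets.append(self._offsets[-1] + len(chunk))
        self._data = b"".join(encoded)

    def __len__(self) -> int:
        return len(self._offsets) - 1

    def __getitem__(self, idx: int) -> str:
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError("PackedStringArray index out of range")
        start, end = self._offsets[idx], self._offsets[idx + 1]
        # Decoding builds a short-lived str; the packed bytes are only read, never modified.
        return self._data[start:end].decode("utf-8")
```

Whether this is worth the extra code depends on the dataset size; for a few thousand file paths a plain list is usually fine.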

Finally, WAVDataset now has an option to extract artist and genre information from the ID3 tags of the WAV files and to maintain a mapping from artist and genre names to integer IDs. This can be useful for training generative models conditioned on this metadata.
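A minimal sketch of the name-to-ID bookkeeping, assuming the (artist, genre) strings have already been read from the tags (tag parsing itself is omitted). The helpers `build_id_mapping` and `load_or_build_mappings` and the JSON layout are hypothetical, loosely modelled on the `metadata_mapping_path` option in the usage example, which is created if it does not exist.

```python
import json
import os
from typing import Dict, List, Tuple


def build_id_mapping(names: List[str]) -> Dict[str, int]:
    """Assign consecutive integer IDs to the distinct names, in sorted order."""
    return {name: i for i, name in enumerate(sorted(set(names)))}


def load_or_build_mappings(path: str, tags: List[Tuple[str, str]]) -> Dict[str, Dict[str, int]]:
    """Load artist/genre ID mappings from `path`, or build them from (artist, genre) pairs and save.

    Hypothetical helper; the layout {"artists": {...}, "genres": {...}} is an assumption.
    """
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    mappings = {
        "artists": build_id_mapping([artist for artist, _ in tags]),
        "genres": build_id_mapping([genre for _, genre in tags]),
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(mappings, f, ensure_ascii=False, indent=2)
    return mappings
```

Sorting before numbering keeps IDs stable across runs for the same set of names. Note that once the file exists it is reused as-is, so names added to the dataset later would need separate handling.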

Examples

Here is an example of how to use the updated WAVDataset class to extract artist and genre information from the ID3 tags of WAV files:

from torch.utils.data import DataLoader
from audio_data_pytorch import WAVDataset

# Create a WAVDataset instance with artist and genre information
dataset = WAVDataset(
    '/path/to/wav/files',
    metadata_mapping_path='/path/to/metadata/mapping.json',  # will be created if it doesn't exist
)

# Use a DataLoader to load the dataset in batches
dataloader = DataLoader(dataset, batch_size=4, num_workers=4)

# Iterate over the dataloader to access the data
for data in dataloader:
    # Each batch is a pair: the audio tensor and a tensor holding the artist and genre IDs
    audio, artist_genre_ids_tensor = data

You can also use the mappings attribute of the WAVDataset instance to convert the artist and genre IDs back to their original names:

# artist_id and genre_id are integer IDs taken from artist_genre_ids_tensor
# Get the artist name from the ID
artist_name = dataset.mappings['artists'].invert[artist_id]

# Get the genre name from the ID
genre_name = dataset.mappings['genres'].invert[genre_id]
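The `.invert` attribute above belongs to the PR's own mapping objects. If the mappings are plain `dict`s from name to ID, the reverse lookup can be sketched by inverting each dictionary once; the helper names `invert_mapping` and `decode_ids` are hypothetical. (The `bidict` package provides the same reverse lookup through its `inverse` view.)

```python
from typing import Dict, List


def invert_mapping(name_to_id: Dict[str, int]) -> Dict[int, str]:
    """Build the reverse ID -> name lookup, failing loudly if two names share an ID."""
    id_to_name = {idx: name for name, idx in name_to_id.items()}
    if len(id_to_name) != len(name_to_id):
        raise ValueError("mapping is not one-to-one; cannot invert it")
    return id_to_name


def decode_ids(ids: List[int], id_to_name: Dict[int, str]) -> List[str]:
    """Convert a batch of integer IDs (e.g. from a tensor's .tolist()) back into names."""
    return [id_to_name[i] for i in ids]
```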

ToDo

flavioschneider commented 1 year ago

A few comments:

  1. Is TensorBackedImmutableStringArray really necessary? Even if you have 10k songs, a list of 10k strings shouldn't be that heavy to store in memory.
  2. I don't like that the random crop has been removed from AllTransform, which is also used in the preprocessing code. It would be more appropriate to add crop_size and random_crop_size as additional parameters to WAVDataset so that the optimization can be applied there; users can then leave the crop out of AllTransform and use the one in WAVDataset instead.
  3. I want to keep WAVDataset as simple and easy to understand as possible. Optimizations are welcome if the difference is significant, but things like extracting custom artist/album information are specific to your dataset and use case and add complexity. A more appropriate approach would be to extend WAVDataset, as other datasets do.
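Points 2 and 3 can be sketched together: the base dataset takes an optional crop length and picks the start frame itself, so only that window is read from disk, while dataset-specific metadata lives in a subclass. Every name here (`CroppedWAVDataset`, `crop_frames`, `random_crop`, `MetadataWAVDataset`) is hypothetical and not the library's `WAVDataset` API; the sketch uses the standard-library `wave` module and returns raw PCM bytes to stay self-contained.

```python
import random
import wave
from typing import Dict, List, Optional, Tuple


class CroppedWAVDataset:
    """Hypothetical map-style dataset that reads only a (random) crop of each WAV file."""

    def __init__(self, paths: List[str], crop_frames: Optional[int] = None, random_crop: bool = True):
        self.paths = paths
        self.crop_frames = crop_frames  # None means "read the whole file"
        self.random_crop = random_crop

    def __len__(self) -> int:
        return len(self.paths)

    def _choose_start(self, total_frames: int) -> int:
        # Files no longer than the crop (or no crop at all) start at frame 0.
        if self.crop_frames is None or total_frames <= self.crop_frames:
            return 0
        if self.random_crop:
            # randint is inclusive, so start + crop_frames never exceeds total_frames.
            return random.randint(0, total_frames - self.crop_frames)
        return 0

    def __getitem__(self, idx: int) -> bytes:
        with wave.open(self.paths[idx], "rb") as wav:
            total = wav.getnframes()
            start = self._choose_start(total)
            wav.setpos(start)  # skip straight to the crop
            count = total - start if self.crop_frames is None else self.crop_frames
            return wav.readframes(count)


class MetadataWAVDataset(CroppedWAVDataset):
    """Dataset-specific extension: adds metadata without touching the base class."""

    def __init__(self, paths: List[str], metadata: Dict[str, Dict[str, str]], **kwargs):
        super().__init__(paths, **kwargs)
        self.metadata = metadata  # path -> {"artist": ..., "genre": ...}

    def __getitem__(self, idx: int) -> Tuple[bytes, Dict[str, str]]:
        audio = super().__getitem__(idx)
        return audio, self.metadata.get(self.paths[idx], {})
```

Keeping the metadata in a subclass leaves the base class free of dataset-specific logic, which is the separation point 3 asks for.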

Related to style:

  • Activate pre-commit hooks so that formatting and checks are standardized.
  • Avoid deep indentation; it makes the code hard to read and understand.
  • If a function is too long, you probably need to split it into multiple functions.
  • Keep everything as simple as possible.
  • Code in a way that explains what you're doing, either with good function names or with comments where something isn't self-explanatory.

zaptrem commented 1 year ago

OK, I fixed all of the above and removed TensorBackedImmutableStringArray (efficiency is an obsession, sorry!). I also factored out all of my metadata code and made it return strings by default (probably more useful to everyone doing text conditioning). I did add one more dependency (bidict) so that artist/genre IDs can be decoded back to names efficiently for logging. I also updated the README.

Re: pre-commit

(base) zaptrem@Holmes:~/dance3/audio-data-pytorch$ pre-commit
[INFO] Initializing environment for https://gitlab.com/pycqa/flake8.
Username for 'https://gitlab.com': zaptrem1
Password for 'https://zaptrem1@gitlab.com':
Username for 'https://gitlab.com': zaptrem1
Password for 'https://zaptrem1@gitlab.com':
An unexpected error has occurred: CalledProcessError: command: ('/usr/bin/git', 'fetch', 'origin', '--tags')
return code: 128
stdout: (none)
stderr:
    remote: The project you were looking for could not be found or you don't have permission to view it.
    fatal: repository 'https://gitlab.com/pycqa/flake8.git/' not found

Check the log at /home/zaptrem/.cache/pre-commit/pre-commit.log

Visiting the above page returns a 404. Edit: I switched the pre-commit config to use the GitHub flake8 repo instead, and that seems to work fine. I ran into the error below while running pre-commit, but I already format my code with black, so it should pass whatever trials and tribulations the hook had in store for me:

black....................................................................Failed
- hook id: black
- exit code: 1

Traceback (most recent call last):
  File "/home/zaptrem/.cache/pre-commit/repop4pv2akh/py_env-python3.9/bin/black", line 8, in <module>
    sys.exit(patched_main())
  File "/home/zaptrem/.cache/pre-commit/repop4pv2akh/py_env-python3.9/lib/python3.9/site-packages/black/__init__.py", line 1372, in patched_main
    patch_click()
  File "/home/zaptrem/.cache/pre-commit/repop4pv2akh/py_env-python3.9/lib/python3.9/site-packages/black/__init__.py", line 1358, in patch_click
    from click import _unicodefun
ImportError: cannot import name '_unicodefun' from 'click' (/home/zaptrem/.cache/pre-commit/repop4pv2akh/py_env-python3.9/lib/python3.9/site-packages/click/__init__.py)

Note: To appease the type checker, I changed idx from a list/int to a plain int, since it doesn't actually seem to be used as a list anywhere (unless I missed something?).
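For context on the `idx` change: with a map-style dataset and the default samplers, PyTorch's `DataLoader` calls `dataset[i]` once per sample with an integer index and then collates the results, unless the dataset defines `__getitems__` or a custom sampler yields lists. Annotating `__getitem__` with `int` therefore matches how it is normally called. A minimal sketch with a hypothetical `IndexedClips` dataset and a `fetch_batch` helper that mimics the per-index calls (plain Python, no torch needed to run it):

```python
from typing import List, Sequence


class IndexedClips:
    """Hypothetical map-style dataset whose __getitem__ takes a single int index."""

    def __init__(self, clips: Sequence[str]):
        self.clips = list(clips)

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> str:
        # One integer index in, one sample out; batching happens outside the dataset.
        return self.clips[idx]


def fetch_batch(dataset: IndexedClips, batch_indices: List[int]) -> List[str]:
    """Mimics the DataLoader's per-batch fetch for datasets without __getitems__."""
    return [dataset[i] for i in batch_indices]
```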