aeromamba-super-resolution / aeromamba

Official implementation of "AEROMamba: An efficient architecture for audio super-resolution using generative adversarial networks and state space models", presented in LAMIR 2024 Workshop
Creative Commons Zero v1.0 Universal
14 stars 1 forks source link

AEROMamba

About

Official PyTorch implementation of

AEROMamba: An efficient architecture for audio super-resolution using generative adversarial networks and state space models

whose demo is available in our Webpage. Our model is closely related to AERO and Mamba, so make sure to check them out if any questions arise regarding these modules.

Installation

Requirements:

Instructions:

If there is any error in the previous step, make sure to install manually the required libs. For PyTorch/CUDA and Mamba, manual installation is done through

Also, make sure to unzip the contents of Mamba (the mamba folder) inside aeromamba/src/models/ .

ViSQOL

We did not use ViSQOL for training and validation, but if you want to, see AERO for instructions.

Datasets

Download data

For popular music we use the mixture tracks of MUSDB18-HQ dataset.

For piano music, we collected a private dataset from CDs whose metadata are described in our Webpage.

Resample data

Data are a collection of high/low resolution pairs. Corresponding high and low resolution signals should be in different folders, eg: hr_dataset and lr_dataset.

In order to create each folder, one should run resample_data a total of 5 times, to include all source/target pairs.

We downsample once to a target 11.025 kHz, from the original 44.1 kHz.

e.g. for 11.025 and 44.1 kHz: \ python data_prep/resample_data.py --data_dir <path for 44.1 kHz data> --out_dir <path for 11.025 kHz data> --target_sr 11025

Create egs files

For each low and high resolution pair, one should create "egs files" twice: for low and high resolution.
create_meta_files.py creates a pair of train and val "egs files", each under its respective folder. Each "egs file" contains meta information about the signals: paths and signal lengths.

python data_prep/create_meta_files.py <path for 11.025 kHz data> egs/musdb/ lr

python data_prep/create_meta_files.py <path for 44.1 kHz data> egs/musdb/ hr

Train

Run train.py with dset and experiment parameters, or set the default values in main_config.yaml file.

python train.py dset=<dset-name> experiment=<experiment-name>

To train with multiple GPUs, run with parameter ddp=true. e.g. python train.py dset=<dset-name> experiment=<experiment-name> ddp=true

Test (on whole dataset)

python test.py dset=<dset-name> experiment=<experiment-name>

Inference

Single sample

python predict.py dset=<dset-name> experiment=<experiment-name> +filename=<absolute path to input file> +output=<absolute path to output directory>

Multiple samples

bash predict_batch.sh <input_folder> <output_folder>

We also provide predict_with_ola.py to predict large files that do not fit in the GPU, without the need for segmentation, using Overlap-and-Add. The original predict.py is also capable of joining predicted segments, but its naïve method causes clicks.

python predict_with_ola.py dset=<dset-name> experiment=<experiment-name> +folder_path=<absolute path to input folder> +output=<absolute path to output directory>

Checkpoints

To use pre-trained models for MUSDB18-HQ or PianoEval data, one can download checkpoints from here.

To link to checkpoint when testing or predicting, override/set path under checkpoint_file:<path> in conf/main_config.yaml. e.g.

python test.py dset=<dset-name> experiment=<experiment-name> +checkpoint_file=<path to checkpoint.th file>

Alternatively, make sure that the checkpoint file is in its corresponding output folder:
For each low to high resolution setting, hydra creates a folder under outputs/<dset-name>/<experiment-name>

Make sure that restart: false in conf/main_config.yaml