sigsep / open-unmix-pytorch

Open-Unmix - Music Source Separation for PyTorch
https://sigsep.github.io/open-unmix/
MIT License

Training Hardware and Issue #121

Open · ayush055 opened this issue 2 years ago

ayush055 commented 2 years ago

Hi, how many RTX 2080 GPUs was the model trained on?

I am trying to use one RTX 2070 GPU with 4 workers but training takes about 10-15 minutes per epoch.

Also, I encounter the following error when validation occurs:

_pickle.PicklingError: Can't pickle <function MUSDBDataset. at 0x000001E5ECB6D5E8>: attribute lookup MUSDBDataset. on openunmix.data failed
EOFError: Ran out of input
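This PicklingError typically means the dataset cannot be pickled for spawned DataLoader worker processes (the default start method on Windows); the lambda audio: audio default used for the source augmentations in openunmix.data is one such unpicklable object. A minimal sketch of a workaround, assuming that lambda is indeed the culprit:

import torch

def no_augmentation(audio: torch.Tensor) -> torch.Tensor:
    # a module-level function can be pickled by spawned worker processes,
    # unlike the `lambda audio: audio` default
    return audio

# either use a named function like this in place of the lambda default for the
# source augmentations, or sidestep worker processes entirely with --nb-workers 0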

faroit commented 2 years ago

@ayush055

ayush055 commented 2 years ago

What specific hyperparameters did you use for training? I am unable to achieve such fast training times per epoch with the MUSDB18HQ dataset.

I am using the exact environment and the wav files.

faroit commented 2 years ago

Can you check your gpu utilization?
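For reference, watching nvidia-smi while training runs shows the compute utilization; from inside the training process, the CUDA memory counters give a rough picture of memory use (a minimal sketch, not part of train.py):

import torch

if torch.cuda.is_available():
    dev = torch.device("cuda:0")
    print(torch.cuda.get_device_name(dev))
    # memory held by tensors vs. memory reserved by the caching allocator
    print(f"allocated: {torch.cuda.memory_allocated(dev) / 1e9:.2f} GB")
    print(f"reserved:  {torch.cuda.memory_reserved(dev) / 1e9:.2f} GB")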

ayush055 commented 2 years ago

It is not being fully used. I am training on an A10 GPU, and only about 8 GB of its 24 GB of memory is being used. Increasing the batch size, however, makes training take even longer.

faroit commented 2 years ago

That's a known issue caused by the musdb dataloader. I might be able to fix this soon.

In the meantime, please use the trackfolder_fix dataloader. It uses pysoundfile and should be significantly faster.

ayush055 commented 2 years ago

Regarding the MUSDB18HQ dataset, is this the correct command?

python train.py --root ~/musdb18hq --nb-workers 16 --dataset trackfolder_fix --target-file vocals.wav --interferer-files bass.wav drums.wav other.wav --epochs 400

With this command, I am getting about 48 seconds per epoch. However, the output printed to the terminal doesn't show the progress bar with the 344 batches. Rather, it shows the following:

[screenshots of the terminal output]

Is this normal?

Is there any way to further speed this up?

faroit commented 2 years ago

Let me try this here and get back to you within the next week.

ayush055 commented 2 years ago

I have tried training the model on the MUSDB18HQ dataset for vocals separation for 400 epochs, but the train loss has not gone below 1. Looking at the graphs, the train loss dropped below 1 within the first few epochs. Why is this?

gzhu06 commented 2 years ago

Hi @ayush055, I also ran into the slow loading issue caused by the musdb dataloader, so I took @faroit's advice and switched to the trackfolder_fix dataloader instead.

I got the same progress bar (6/6 in my case) as yours, with an epoch taking only a few seconds. This differs from the musdb dataloader because there is no --samples-per-track parameter: the trackfolder_fix dataloader simply takes one segment (six seconds when --seq-dur is set to 6.0) from each track per epoch.

So I followed the MUSDBDataset class and added samples_per_track to FixedSourcesTrackFolderDataset for training. In my test, after one epoch of training, the two dataloaders yield a similar loss. For reference, musdb takes around 15 minutes per epoch while trackfolder_fix takes about 40 seconds (nb-workers=8).

Hope @faroit can fix musdb soon!

My code:

# note: UnmixDataset, load_audio and the get_tracks() helper are as in the
# upstream openunmix/data.py and are omitted here
import random
from pathlib import Path
from typing import Callable, List, Optional

import torch


class FixedSourcesTrackFolderDataset(UnmixDataset):
    def __init__(
        self,
        root: str,
        split: str = "train",
        target_file: str = "vocals.wav",
        interferer_files: List[str] = ["bass.wav", "drums.wav"],
        seq_duration: Optional[float] = None,
        random_chunks: bool = False,
        random_track_mix: bool = False,
        samples_per_track: int = 64,
        source_augmentations: Optional[Callable] = lambda audio: audio,
        sample_rate: float = 44100.0,
        seed: int = 42) -> None:

        self.root = Path(root).expanduser()
        self.split = split
        self.sample_rate = sample_rate
        self.seq_duration = seq_duration
        self.random_track_mix = random_track_mix
        self.random_chunks = random_chunks
        self.source_augmentations = source_augmentations
        # set the input and output files (accept glob)
        self.target_file = target_file
        self.samples_per_track = samples_per_track
        self.interferer_files = interferer_files
        self.source_files = self.interferer_files + [self.target_file]
        self.seed = seed
        random.seed(self.seed)

        self.tracks = list(self.get_tracks())
        if not len(self.tracks):
            raise RuntimeError("No tracks found")

    def __getitem__(self, index):

        # select track
        track = self.tracks[index // self.samples_per_track]

        # first, get target track
        track_path = track["path"]
        min_duration = track["min_duration"]
        if self.random_chunks:
            # determine start seek by target duration
            start = random.uniform(0, min_duration - self.seq_duration)
        else:
            start = 0

        # assemble the mixture of target and interferers
        audio_sources = []
        # load target
        target_audio, _ = load_audio(track_path / self.target_file, 
                                     start=start, dur=self.seq_duration)
        target_audio = self.source_augmentations(target_audio)
        audio_sources.append(target_audio)
        # load interferers
        for source in self.interferer_files:
            # optionally select a random track for each source
            if self.random_track_mix:
                random_idx = random.choice(range(len(self.tracks)))
                track_path = self.tracks[random_idx]["path"]
                if self.random_chunks:
                    min_duration = self.tracks[random_idx]["min_duration"]
                    start = random.uniform(0, min_duration - self.seq_duration)

            audio, _ = load_audio(track_path / source, start=start, dur=self.seq_duration)
            audio = self.source_augmentations(audio)
            audio_sources.append(audio)

        stems = torch.stack(audio_sources)
        # apply a linear mix by summing over the source dimension (dim 0)
        x = stems.sum(0)
        # target is always the first element in the list
        y = stems[0]
        return x, y
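Note that for the added samples_per_track to take effect, __len__ presumably needs to be scaled as well (it is not shown in the excerpt above); a minimal sketch of that adjustment:

    def __len__(self):
        # each track is visited samples_per_track times per epoch, matching the
        # index // self.samples_per_track lookup in __getitem__
        return len(self.tracks) * self.samples_per_track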

Training loss using the musdb dataloader:

[training loss plot]

Training loss using trackfolder_fix dataloader:

[training loss plot]

Loss already below 1

ayush055 commented 2 years ago

Oh cool! Thanks. I was able to fix the musdb dataloader by using the load_audio function instead of the dataframe it currently uses. I think the bottleneck is in actually loading the audio files; using torchaudio speeds up the process significantly, and I was able to replicate the results.
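The exact patch is not shown in this thread; a rough sketch of the idea, assuming the MUSDB18-HQ root/<split>/<track>/<source>.wav layout and reusing the load_audio helper from openunmix.data (so only the required chunk of each wav file is decoded via torchaudio):

from pathlib import Path

import torch
from openunmix.data import load_audio  # same helper used in the dataset code above

def load_source_chunk(root: str, split: str, track_name: str, source: str,
                      start: float, dur: float) -> torch.Tensor:
    # decode only `dur` seconds starting at `start` instead of the whole track
    path = Path(root).expanduser() / split / track_name / f"{source}.wav"
    audio, _ = load_audio(path, start=start, dur=dur)
    return audio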

gzhu06 commented 2 years ago

Yes, I think the data loading pipeline of musdb is problematic; the trackfolder_fix dataloader uses torchaudio instead. I didn't look into the details, but I guess there are some bugs in how the audio backend of the musdb package is used.