facebookresearch / audioseal

Localized watermarking for AI-generated speech, with state-of-the-art robustness and a very fast detector
MIT License
445 stars 55 forks

Slow training #48

Open christianc102 opened 3 months ago

christianc102 commented 3 months ago

Hi!

Thanks so much for the helpful training code and documentation. Apologies in advance for the naive question--I'm pretty new to machine learning.

I'm trying to train my own watermarking model at 48 kHz on my own dataset, using an 8-GPU H100 node (H100 80GB HBM3) on a remote SLURM cluster, but as I scale the batch size, training speed appears to drop proportionally. There is also unexpected behavior where I specify dataset.batch_size=k but the submitted config (logged by wandb) shows dataset.batch_size=k/8.

For example, when I set dataset.batch_size=8 (which became dataset.batch_size=1), the maximum training speed was about 1.67 steps/second, with GPU utilization averaging around 25%. When I set dataset.batch_size=128 (yielding dataset.batch_size=16), training speed dropped to around 0.3 steps/second. Based on these results, it seems that parallelization isn't working the way it should?
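For what it's worth, the k/8 behavior described above is consistent with a framework that treats `dataset.batch_size` as a *global* batch size and divides it across processes. A minimal sketch of that assumption (this is illustrative, not audioseal/audiocraft code):

```python
# Sketch (assumption): if the configured batch size is global, each of the
# 8 GPUs gets batch_size / world_size samples per step, which would explain
# dataset.batch_size=128 being logged as 16.
def per_gpu_batch_size(global_batch_size: int, world_size: int) -> int:
    """Per-GPU batch size when the configured batch size is global."""
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must be divisible by world size")
    return global_batch_size // world_size
```

Under this reading, steps/second dropping as the batch size grows is not necessarily a parallelization failure: each step simply processes more samples, so samples/second is the fairer comparison.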

I've tried preprocessing my dataset into one-second clips and removing some of the augmentations (even running an experiment with only noise augmentation) to try to increase GPU utilization, but nothing I've tried has improved the training speed.
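For reference, the clip preprocessing mentioned above can be done with a simple fixed-length split; this is just a generic sketch of that step (not the repo's own preprocessing code), using numpy and a 48 kHz sample rate:

```python
import numpy as np

def split_into_clips(waveform: np.ndarray, sample_rate: int = 48_000,
                     clip_seconds: float = 1.0) -> list:
    """Split a 1-D waveform into fixed-length clips, dropping the remainder."""
    clip_len = int(sample_rate * clip_seconds)
    n_clips = len(waveform) // clip_len
    return [waveform[i * clip_len:(i + 1) * clip_len] for i in range(n_clips)]
```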

Is this to be expected? Roughly how long did the original model take to train, using what amount of compute?

Thank you so much!

cyrannano commented 2 months ago

@christianc102

Hi! I encountered a similar issue during my training. One thing that helped was updating the data hyperparameters in the solver configuration, particularly the dataset.[train, valid, evaluate].num_samples parameters. My understanding is that when these values don't align with your dataset size, GPU utilization can be inefficient.
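To see why the alignment matters: if num_samples controls how many samples are drawn per epoch, it directly sets the number of updates per epoch, so a value much larger than your actual dataset makes each epoch far longer than one pass over the data. A back-of-the-envelope helper (illustrative only, not an audiocraft API):

```python
# Sketch: how num_samples and batch size determine epoch length.
def updates_per_epoch(num_samples: int, batch_size: int) -> int:
    """Optimizer updates per epoch for a given dataset.<split>.num_samples."""
    return num_samples // batch_size
```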

I hope this helps!

zjcqn commented 6 days ago

I found that the PESQ computation in audiocraft/solvers/watermark.py is very time-consuming, so I skipped it.
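Rather than deleting the PESQ call outright, one generic pattern is to gate the slow metric behind a flag so it can still be enabled for evaluation runs. A hedged sketch (the `compute_pesq` flag and `pesq_fn` callable are hypothetical, not audiocraft config keys):

```python
import numpy as np

def eval_metrics(ref: np.ndarray, deg: np.ndarray,
                 compute_pesq: bool = False, pesq_fn=None) -> dict:
    """Always compute cheap metrics; run the slow PESQ call only on request."""
    noise = ref - deg
    snr = 10.0 * np.log10(np.sum(ref ** 2) / (np.sum(noise ** 2) + 1e-12))
    metrics = {"snr": float(snr)}
    if compute_pesq and pesq_fn is not None:
        metrics["pesq"] = pesq_fn(ref, deg)  # time-consuming; skip during training
    return metrics
```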