I didn't thoroughly test vk diffusion (usually I go with v), but I also never had exploding problems. Check that your dataset is distributed evenly, i.e. if a batch suddenly contains multiple silent samples, that can mess up the model. For that, you might want to use the WAVDataset in https://github.com/archinetai/audio-data-pytorch with `check_silence` set to true, or do some similar checks.
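For reference, a minimal sketch of that dataset setup (the path here is a placeholder; see the audio-data-pytorch README for the full argument list):

```python
from audio_data_pytorch import WAVDataset

# Placeholder path; check_silence enables the silence check mentioned above.
dataset = WAVDataset(path="./samples", check_silence=True)
```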
I am actually using WAVDataset with `check_silence` set to `True` (which is the default). Also, the datasets I'm using are taken from one-shot sound packs, and another was a set of wavetables, so that's apparently not the problem...
Thanks for the help btw! :)
Turns out this was a good old "too high a learning rate" problem...
I was using the default optimizer settings for `lr`, `betas`, `eps` and `weight_decay`.
Using the base configuration from audio-diffusion-pytorch-trainer solved this issue:
```python
self.optimizer = torch.optim.AdamW(
    params=list(self.model.parameters()),
    lr=1e-4,
    betas=(0.95, 0.999),
    eps=1e-6,
    weight_decay=1e-3,
)
```
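If the defaults in question were `torch.optim.AdamW`'s (`lr=1e-3`, `betas=(0.9, 0.999)`, `eps=1e-8`, `weight_decay=1e-2`), the main change here is the 10x lower learning rate, along with a slightly larger `eps` and smaller `weight_decay`.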
The loss suddenly increases from <0.1 to billions over one or two epochs.
I'm training an `AudioDiffusionModel` and I've had this happen with both the default `diffusion_type='v'` and with `diffusion_type='vk'`. It also happens both with and without gradient clipping. It's happened with several datasets and different batch sizes (the output below is from a particularly small dataset with a large batch size). It seems to happen more often the closer the loss gets to 0.
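For context, here's a minimal sketch of this kind of setup (not the actual code from this issue; the channel count, batch size and sample length are placeholder values, and the exact constructor arguments may differ between library versions):

```python
import torch
from audio_diffusion_pytorch import AudioDiffusionModel

# Sketch only: these values are placeholders, not the ones from this issue.
model = AudioDiffusionModel(in_channels=1, diffusion_type='v')  # same failure seen with 'vk'

audio = torch.randn(4, 1, 2**18)  # [batch, channels, samples]
loss = model(audio)               # forward pass returns the diffusion loss
loss.backward()
```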
Output:
The model:
Training: