lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

RuntimeError: stack expects each tensor to be equal size, but got [5440] at entry 0 and [5120] at entry 2 #55

Closed: turian closed this 1 year ago

turian commented 1 year ago

The code appears to be able to handle audio of varying lengths. Indeed, LibriSpeech contains clips of different lengths.

However, when I train on a corpus of mixed-length audio, I get the following error:

[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.
Traceback (most recent call last):
  File "/root/foo.py", line 19, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 411, in train
    logs = self.train_step()
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 302, in train_step
    wave, = next(self.dl_iter)
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/trainer.py", line 70, in cycle
    for data in dl:
  File "/opt/conda/lib/python3.10/site-packages/accelerate/data_loader.py", line 375, in __iter__
    current_batch = next(dataloader_iter)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
    return self.collate_fn(data)
  File "/opt/conda/lib/python3.10/site-packages/audiolm_pytorch/data.py", line 98, in inner
    data = torch.stack(data)
RuntimeError: stack expects each tensor to be equal size, but got [5440] at entry 0 and [5120] at entry 2
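The failure can be reproduced in isolation: `torch.stack` refuses tensors of unequal length, which is what the collate function in `audiolm_pytorch/data.py` hits at `torch.stack(data)`. Below is a minimal sketch of one generic workaround, zero-padding each batch to its longest clip with `torch.nn.utils.rnn.pad_sequence`. `pad_collate` is a hypothetical name; this is not necessarily what the upstream fix does, and it is not a drop-in replacement for the collate function in `data.py` (the trainer there unpacks a tuple via `wave, = next(self.dl_iter)`).

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # Hypothetical collate_fn. batch: list of 1-D waveform tensors of possibly
    # different lengths. pad_sequence right-pads with zeros up to the longest one.
    return pad_sequence(batch, batch_first=True, padding_value=0.0)

# Three mono clips whose lengths match the sizes in the error message
clips = [torch.randn(5440), torch.randn(5440), torch.randn(5120)]

try:
    torch.stack(clips)  # raises: stack requires every tensor to have the same size
except RuntimeError as err:
    print("stack failed:", err)

batch = pad_collate(clips)
print(batch.shape)  # torch.Size([3, 5440])
```

Padding keeps every clip intact but adds trailing silence to the shorter ones, so it treats the symptom rather than the resample-ordering bug.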

turian commented 1 year ago

Digging into this a bit more: the audio I'm using has sample rates of 48000 Hz and 44100 Hz.

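The bug appears to be that resampling happens after padding/trimming in data.py, so clips padded or trimmed to the same number of samples at their native rates can end up with different lengths once resampled.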
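To see why the order matters, here is a minimal sketch comparing pad/trim-then-resample with resample-then-pad/trim. It assumes a 16 kHz target rate and a fixed max length of 16000 samples, and it swaps in a crude linear-interpolation resampler (`naive_resample`, built on `torch.nn.functional.interpolate`) for a real one such as torchaudio's, because only the output lengths matter here. All function names and constants are hypothetical; this is not the code in `data.py`.

```python
import math

import torch
import torch.nn.functional as F

TARGET_SR = 16000  # assumed target sample rate
MAX_LEN = 16000    # assumed max length in samples (1 s at TARGET_SR)

def naive_resample(wave, orig_sr, new_sr):
    # Hypothetical stand-in for a real resampler: linear interpolation,
    # used only to produce the right output length, not good audio
    new_len = math.ceil(wave.shape[-1] * new_sr / orig_sr)
    return F.interpolate(wave[None, None, :], size=new_len,
                         mode="linear", align_corners=False)[0, 0]

def pad_or_trim(wave, length):
    # Trim to `length` samples, or right-pad with zeros up to `length`
    if wave.shape[-1] >= length:
        return wave[:length]
    return F.pad(wave, (0, length - wave.shape[-1]))

def buggy(wave, sr):
    # Pad/trim at the native rate, then resample: output length depends on sr
    return naive_resample(pad_or_trim(wave, MAX_LEN), sr, TARGET_SR)

def fixed(wave, sr):
    # Resample first, then pad/trim at the target rate: always MAX_LEN samples
    return pad_or_trim(naive_resample(wave, sr, TARGET_SR), MAX_LEN)

clip_48k = torch.randn(48000 * 2)  # 2 s at 48 kHz
clip_44k = torch.randn(44100 * 2)  # 2 s at 44.1 kHz

print(buggy(clip_48k, 48000).shape, buggy(clip_44k, 44100).shape)  # lengths differ
print(fixed(clip_48k, 48000).shape, fixed(clip_44k, 44100).shape)  # both MAX_LEN
```

With trim-first ordering the two clips come out at different lengths (the exact numbers depend on the resampler), which is the kind of mismatch `torch.stack` rejects; resampling first makes the fixed length hold at the target rate regardless of the source rate.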

BTW, the torchaudio docs note:

"transforms.Resample precomputes and reuses the resampling kernel, so using it will result in more efficient computation if resampling multiple waveforms with the same resampling parameters."

turian commented 1 year ago

@lucidrains I started to write a PR, but given this bug I wasn't sure what the desired behavior is when multiple sample rates are specified:

A few comments would be helpful :)

lucidrains commented 1 year ago

@turian hey Joseph! thanks for identifying this issue

put in a fix

and your other point is a good one; let me make sure one can specify a different target max length per resample frequency as well

lucidrains commented 1 year ago

@turian ok done, let me know if 0.4.6 works!

turian commented 1 year ago

@lucidrains Okay it works! That's wonderful. Thanks Phil

Do you mind adding a comment clarifying the intended batch shape when multiple target sample rates are defined?
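For the batch-shape question, one plausible layout (an assumption, not something the thread confirms) is that each dataset item is a tuple with one waveform per target sample rate, and the collate step stacks each position separately, producing a tuple of tensors shaped `(batch, max_len_i)`. That would be consistent with the single-rate case in the traceback, where the trainer unpacks a one-element tuple (`wave, = next(self.dl_iter)`). `collate_per_rate` is a hypothetical name.

```python
import torch

def collate_per_rate(batch):
    # Hypothetical collate_fn. batch: list of tuples, one tensor per target rate,
    # e.g. [(wave_16k, wave_24k), (wave_16k, wave_24k), ...].
    # zip(*batch) groups the i-th waveform of every item; each group is stacked
    # on its own, so tensors of different lengths never meet in one torch.stack.
    return tuple(torch.stack(group) for group in zip(*batch))

items = [(torch.randn(32000), torch.randn(48000)) for _ in range(4)]
batch_16k, batch_24k = collate_per_rate(items)
print(batch_16k.shape, batch_24k.shape)  # torch.Size([4, 32000]) torch.Size([4, 48000])
```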