lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

`data_max_length_seconds` causes typecheck error in `CoarseTransformerTrainer` #259

Closed orrp closed 6 months ago

orrp commented 6 months ago

When creating a `CoarseTransformerTrainer` with `data_max_length_seconds` (rather than `data_max_length`), the following type error is raised:

File "audiolm-pytorch/train.py", line 48, in make_coarse_trainer
    trainer = CoarseTransformerTrainer(
  File "<@beartype(audiolm_pytorch.trainer.CoarseTransformerTrainer.__init__) at 0x7f056897a4c0>", line 196, in __init__
  File "audiolm-pytorch/audiolm_pytorch/trainer.py", line 1090, in __init__
    self.ds = SoundDataset(
  File "<@beartype(audiolm_pytorch.data.SoundDataset.__init__) at 0x7f05689709d0>", line 58, in __init__
beartype.roar.BeartypeCallHintParamViolation: Method audiolm_pytorch.data.SoundDataset.__init__() parameter max_length=(160000, 240000) violates type hint typing.Optional[int], as tuple (160000, 240000) not <class "builtins.NoneType"> or int.

This is because `SoundDataset` indeed expects `max_length` to be either `None` or an `int`, but `CoarseTransformerTrainer` passes it a tuple. At `trainer.py:1089` we have:

            if exists(data_max_length_seconds):
                data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, codec.target_sample_hz))

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,

Note that this issue only affects the Coarse trainer. For convenience, here is how `data_max_length` is derived from `data_max_length_seconds` in the other trainers:

data_max_length = int(data_max_length_seconds * soundstream.target_sample_hz) # SoundStreamTrainer
data_max_length = data_max_length_seconds * wav2vec.target_sample_hz # SemanticTransformerTrainer
data_max_length = data_max_length_seconds * codec.target_sample_hz # FineTransformerTrainer

In the SoundDataset constructor we have

max_length: Optional[int] = None,               # max length would apply to the highest target_sample_hz, if there are multiple

Therefore my guess is that `CoarseTransformerTrainer` should pass the `max()` of the tuple rather than the tuple itself. But perhaps it would be cleaner to change the type hint to `Union[int, Tuple[int, ...]]` (as is done for `target_sample_hz`) and take the `max()` inside the `SoundDataset` constructor when a tuple is passed.
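As a rough, untested sketch of the two options (the helper names below are mine, not from the library; the sample rates 16 kHz / 24 kHz are illustrative, matching the `(160000, 240000)` tuple in the traceback):

```python
from typing import Optional, Tuple, Union

# Option 1: trainer-side fix -- collapse the tuple before constructing SoundDataset.
def coarse_max_length(data_max_length_seconds: int, wav2vec_hz: int, codec_hz: int) -> int:
    # max() because SoundDataset's comment says max_length applies to the
    # highest target_sample_hz when there are multiple.
    return max(data_max_length_seconds * hz for hz in (wav2vec_hz, codec_hz))

# Option 2: dataset-side fix -- widen the type hint and normalize in SoundDataset.__init__.
def normalize_max_length(max_length: Optional[Union[int, Tuple[int, ...]]]) -> Optional[int]:
    if isinstance(max_length, tuple):
        return max(max_length)
    return max_length

print(coarse_max_length(10, 16000, 24000))     # -> 240000
print(normalize_max_length((160000, 240000)))  # -> 240000
```

Either way, the dataset ends up with a single `int`, which satisfies the existing beartype hint.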

Thoughts?