lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

IndexError Using Encodec and setting return_coarse_generated_wave=True #246

Closed: rgxb2807 closed this issue 8 months ago

rgxb2807 commented 8 months ago

Currently training the CoarseTransformer using the Encodec wrapper. When setting return_coarse_generated_wave=True, I get an IndexError; this doesn't happen when it defaults to False. I'm trying this on my local machine's CPU, since my CUDA machine is still training the CoarseTransformer.

```python
generated_wav = audiolm(prime_wave_path=prime_wave_path, return_coarse_generated_wave=True)
```

I seem to be getting comparable results from my SemanticTransformer, as discussed here (thank you @eonglints). I instantiate an untrained instance of FineTransformer just so I can instantiate an AudioLM instance, the idea being to develop the CoarseTransformer first before moving on to the FineTransformer.
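For context, a minimal sketch of that setup (paths, dims, and depths below are placeholders, not values from this thread): trained SemanticTransformer and CoarseTransformer, an untrained FineTransformer only so AudioLM can be constructed, and EncodecWrapper as the codec, following the README-style constructors.

```python
# Sketch only -- paths and hyperparameters are placeholders, not from this thread
from audiolm_pytorch import (
    HubertWithKmeans, SemanticTransformer, CoarseTransformer,
    FineTransformer, EncodecWrapper, AudioLM
)

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',
    kmeans_path = './hubert_base_ls960_L9_km500.bin'
)

encodec = EncodecWrapper()  # pretrained 24kHz Encodec: 8 quantizers total (3 coarse + 5 fine here)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size, dim = 1024, depth = 6
)

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024, num_coarse_quantizers = 3, dim = 512, depth = 6
)

fine_transformer = FineTransformer(  # untrained -- only needed so AudioLM can be built
    num_coarse_quantizers = 3, num_fine_quantizers = 5,
    codebook_size = 1024, dim = 512, depth = 6
)

# ... load the trained semantic / coarse weights here ...

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = encodec,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)
```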

I can also instantiate an untrained instance of SoundStream and pass that as my codec to AudioLM; with that codec, setting return_coarse_generated_wave=True does not produce the indexing error.

My initial assumption was that something was going on with the rearrange call here. When return_coarse_generated_wave=True, the shape of codes is torch.Size([3, 1, 812]), but when return_coarse_generated_wave=False, the shape of codes is torch.Size([8, 1, 390]). I'm guessing the 8 vs. 3 discrepancy is the number of coarse quantizers vs. the combined coarse + fine quantizers.
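A quick sanity check of those shapes (illustrative only, not repo code): with 3 coarse quantizers, quantized_indices arrives as (batch, tokens, quantizers), and the rearrange turns it into (quantizers, batch, tokens), which is where the [3, 1, 812] comes from.

```python
import torch
from einops import rearrange

quantized_indices = torch.randint(0, 1024, (1, 812, 3))  # b t q, coarse-only ids
codes = rearrange(quantized_indices, 'b t q -> q b t')
print(codes.shape)  # torch.Size([3, 1, 812])
```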

I took a shot in the dark and tried to match the successful input dimensions by padding with 0, just to see if the indexing error goes away, but it did not.

```python
    def _decode_frame(self, quantized_indices):
        # The following code is hacked in from self.model._decode_frame() (Encodec version 0.1.1) where we assume we've
        # already unwrapped the EncodedFrame
        # Input: batch x num tokens x num quantizers
        # Output: batch x new_num_samples, where new_num_samples is num_frames * stride product (may be slightly
        # larger than original num samples as a result, because the last frame might not be "fully filled" with samples
        # if num_samples doesn't divide perfectly).
        # num_frames == the number of acoustic tokens you have, one token per frame
        codes = rearrange(quantized_indices, 'b t q -> q b t')
        # attempted workaround: pad the coarse-only codes (3 quantizers) up to the
        # full 8 quantizers with zeros -- this did not make the IndexError go away
        if codes.shape[0] == 3:
            pad = torch.full((8, 1, codes.shape[2]), 0)
            pad[:codes.size(0)] = codes
            codes = pad
        emb = self.model.quantizer.decode(codes)
        # emb shape: batch x self.model.quantizer.dimension x T. Note self.model.quantizer.dimension is the embedding dimension
        return self.model.decoder(emb)
```
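Since F.embedding only raises `index out of range in self` when some index value is at least as large as the embedding table, another thing worth checking is whether any of the sampled coarse ids fall outside Encodec's codebook, rather than the number of quantizer rows. A rough diagnostic sketch below; the .semantic / .coarse attributes and keyword arguments mirror the calls visible in the stack trace, and max_length is just a placeholder value.

```python
# Rough diagnostic sketch, not repo code: check the *values* of the sampled coarse
# ids against Encodec's codebook size (1024 entries per quantizer for the 24kHz model).
import torch

with torch.no_grad():
    semantic_ids = audiolm.semantic.generate(batch_size = 1, max_length = 512)  # placeholder max_length
    coarse_ids = audiolm.coarse.generate(
        semantic_token_ids = semantic_ids,
        reconstruct_wave = False  # return raw codebook indices instead of decoding to audio
    )

codebook_size = 1024
print(coarse_ids.shape, coarse_ids.min().item(), coarse_ids.max().item())
if coarse_ids.max().item() >= codebook_size:
    print('some sampled ids exceed the codebook size -- decoding them would raise this IndexError')
```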

Here's the full stack trace:


```
generating coarse: 100%|██████████████████████| 512/512 [08:24<00:00,  1.02it/s]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[7], line 4
      1 # generated_wav = audiolm(batch_size = 1)
      2 # generated_wav = audiolm(text=['This is a test of audio LM'])
      3 prime_wave_path = "/audio/Documents/audio_explore/test_audio_primer_mono.wav"
----> 4 generated_wav = audiolm(prime_wave_path=prime_wave_path, return_coarse_generated_wave=True)
      5 # generated_wav = audiolm(batch_size=1, return_coarse_generated_wave=True)

File ~/.env/huggingface/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.env/huggingface/lib/python3.9/site-packages/audiolm_pytorch/audiolm_pytorch.py:72, in eval_decorator.<locals>.inner(model, *args, **kwargs)
     70 was_training = model.training
     71 model.eval()
---> 72 out = fn(model, *args, **kwargs)
     73 model.train(was_training)
     74 return out

File ~/.env/huggingface/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/.env/huggingface/lib/python3.9/site-packages/audiolm_pytorch/audiolm_pytorch.py:2163, in AudioLM.forward(self, batch_size, text, text_embeds, prime_wave, prime_wave_input_sample_hz, prime_wave_path, max_length, return_coarse_generated_wave, mask_out_generated_fine_tokens)
   2153     prime_wave = prime_wave.to(self.device)
   2155 semantic_token_ids = self.semantic.generate(
   2156     text_embeds = text_embeds if self.semantic_has_condition else None,
   2157     batch_size = batch_size,
   (...)
   2160     max_length = max_length
   2161 )
-> 2163 coarse_token_ids_or_recon_wave = self.coarse.generate(
   2164     text_embeds = text_embeds if self.coarse_has_condition else None,
   2165     semantic_token_ids = semantic_token_ids,
   2166     prime_wave = prime_wave,
   2167     prime_wave_input_sample_hz = prime_wave_input_sample_hz,
   2168     reconstruct_wave = return_coarse_generated_wave
   2169 )
   2171 if return_coarse_generated_wave:
   2172     return coarse_token_ids_or_recon_wave

File ~/.env/huggingface/lib/python3.9/site-packages/audiolm_pytorch/audiolm_pytorch.py:72, in eval_decorator.<locals>.inner(model, *args, **kwargs)
     70 was_training = model.training
     71 model.eval()
---> 72 out = fn(model, *args, **kwargs)
     73 model.train(was_training)
     74 return out

File ~/.env/huggingface/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File <@beartype(audiolm_pytorch.audiolm_pytorch.CoarseTransformerWrapper.generate) at 0x129dd89d0>:78, in generate(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_4989645184, __beartype_object_4467279376, __beartype_getrandbits, *args, **kwargs)

File ~/.env/huggingface/lib/python3.9/site-packages/audiolm_pytorch/audiolm_pytorch.py:1668, in CoarseTransformerWrapper.generate(self, semantic_token_ids, prime_wave, prime_wave_input_sample_hz, prime_coarse_token_ids, text, text_embeds, max_time_steps, cond_scale, filter_thres, temperature, reconstruct_wave, use_kv_cache, **kwargs)
   1664     return sampled_coarse_token_ids
   1666 assert exists(self.codec)
-> 1668 wav = self.codec.decode_from_codebook_indices(sampled_coarse_token_ids)
   1669 return rearrange(wav, 'b 1 n -> b n')

File ~/.env/huggingface/lib/python3.9/site-packages/audiolm_pytorch/encodec.py:151, in EncodecWrapper.decode_from_codebook_indices(self, quantized_indices)
    142 assert self.model.sample_rate == 24000,\
    143     "if changing to 48kHz, that model segments its audio into lengths of 1.0 second with 1% overlap, whereas " \
    144     "the 24kHz doesn't segment at all. this means the frame decode logic might change; this is a reminder to " \
    145     "double check that."
    146 # Since 24kHz pretrained doesn't do any segmenting, we have all the frames already (1 frame = 1 token in quantized_indices)
    147 
    148 # The following code is hacked in from self.model.decode() (Encodec version 0.1.1) where we skip the part about
    149 # scaling.
    150 # Shape: 1 x (num_frames * stride product). 1 because we have 1 frame (because no segmenting)
--> 151 frames = self._decode_frame(quantized_indices)
    152 result = _linear_overlap_add(frames, self.model.segment_stride or 1)
    153 # TODO: I'm not overly pleased with this because when this function gets called, we just rearrange the result
    154 #   back to b n anyways, but we'll keep this as a temporary hack just to make things work for now

File ~/.env/huggingface/lib/python3.9/site-packages/audiolm_pytorch/encodec.py:175, in EncodecWrapper._decode_frame(self, quantized_indices)
    166 def _decode_frame(self, quantized_indices):
    167     # The following code is hacked in from self.model._decode_frame() (Encodec version 0.1.1) where we assume we've
    168     # already unwrapped the EncodedFrame
   (...)
    172     # if num_samples doesn't divide perfectly).
    173     # num_frames == the number of acoustic tokens you have, one token per frame
    174     codes = rearrange(quantized_indices, 'b t q -> q b t')
--> 175     emb = self.model.quantizer.decode(codes)
    176     # emb shape: batch x self.model.quantizer.dimension x T. Note self.model.quantizer.dimension is the embedding dimension
    177     return self.model.decoder(emb)

File ~/.env/huggingface/lib/python3.9/site-packages/encodec/quantization/vq.py:112, in ResidualVectorQuantizer.decode(self, codes)
    109 def decode(self, codes: torch.Tensor) -> torch.Tensor:
    110     """Decode the given codes to the quantized representation.
    111     """
--> 112     quantized = self.vq.decode(codes)
    113     return quantized

File ~/.env/huggingface/lib/python3.9/site-packages/encodec/quantization/core_vq.py:361, in ResidualVectorQuantization.decode(self, q_indices)
    359 for i, indices in enumerate(q_indices):
    360     layer = self.layers[i]
--> 361     quantized = layer.decode(indices)
    362     quantized_out = quantized_out + quantized
    363 return quantized_out

File ~/.env/huggingface/lib/python3.9/site-packages/encodec/quantization/core_vq.py:288, in VectorQuantization.decode(self, embed_ind)
    287 def decode(self, embed_ind):
--> 288     quantize = self._codebook.decode(embed_ind)
    289     quantize = self.project_out(quantize)
    290     quantize = rearrange(quantize, "b n d -> b d n")

File ~/.env/huggingface/lib/python3.9/site-packages/encodec/quantization/core_vq.py:202, in EuclideanCodebook.decode(self, embed_ind)
    201 def decode(self, embed_ind):
--> 202     quantize = self.dequantize(embed_ind)
    203     return quantize

File ~/.env/huggingface/lib/python3.9/site-packages/encodec/quantization/core_vq.py:188, in EuclideanCodebook.dequantize(self, embed_ind)
    187 def dequantize(self, embed_ind):
--> 188     quantize = F.embedding(embed_ind, self.embed)
    189     return quantize

File ~/.env/huggingface/lib/python3.9/site-packages/torch/nn/functional.py:2210, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2204     # Note [embedding_renorm set_grad_enabled]
   2205     # XXX: equivalent to
   2206     # with torch.no_grad():
   2207     #   torch.embedding_renorm_
   2208     # remove once script supports set_grad_enabled
   2209     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self
```
rgxb2807 commented 8 months ago

OK, it looks like I'm getting a different error on CUDA. It's still happening in the _decode_frame() call, but it appears to happen when the embedding is passed to the decoder, not at the quantization step.

        175 emb = self.model.quantizer.decode(codes)
        176 # emb shape: batch x self.model.quantizer.dimension x T. Note self.model.quantizer.dimension is the embedding dimension
    --> 177 return self.model.decoder(emb)
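One likely explanation: CUDA kernel launches are asynchronous, so an out-of-range index can trip a device-side assert (the `srcIndex < srcSelectDimSize` lines near the top of the trace below) and then surface later as an unrelated-looking cuDNN failure inside the decoder. Forcing synchronous launches, or rerunning on CPU as above, usually pins the error to the real call site. A minimal sketch:

```python
# Sketch: make CUDA launches synchronous so the failing op reports the error itself.
# Must be set before CUDA is initialized, e.g. at the very top of the script or notebook.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```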

Here's the full trace:


    generating coarse: 100%|██████████████████████| 512/512 [01:03<00:00,  8.12it/s]
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [102,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [103,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [104,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [105,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [106,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [107,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [108,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [109,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [110,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [111,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [112,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [113,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [114,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [115,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [116,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [117,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [118,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [119,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [120,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [121,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [122,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [123,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [124,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [125,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    ../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [59,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

    ---------------------------------------------------------------------------

    RuntimeError                              Traceback (most recent call last)

    Cell In[19], line 3
          1 # generated_wav = audiolm(batch_size = 1, return_coarse_generated_wave=True)
          2 # generated_wav = audiolm(text=['This is a test of audio LM'])
    ----> 3 generated_wav = audiolm(prime_wave_path="/audio/generated_samples/test_audio_primer.wav", return_coarse_generated_wave=True)
          6 # generated = torch.stack(generated_wav[0] + generated_wav[1])
          8 def format_stack(a,b):

    File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
       1496 # If we don't have any hooks, we want to skip the rest of the logic in
       1497 # this function, and just call forward.
       1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1499         or _global_backward_pre_hooks or _global_backward_hooks
       1500         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1501     return forward_call(*args, **kwargs)
       1502 # Do not call functions when jit is used
       1503 full_backward_hooks, non_full_backward_hooks = [], []

    File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/audiolm_pytorch.py:72, in eval_decorator.<locals>.inner(model, *args, **kwargs)
         70 was_training = model.training
         71 model.eval()
    ---> 72 out = fn(model, *args, **kwargs)
         73 model.train(was_training)
         74 return out

    File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
        112 @functools.wraps(func)
        113 def decorate_context(*args, **kwargs):
        114     with ctx_factory():
    --> 115         return func(*args, **kwargs)

    File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/audiolm_pytorch.py:2163, in AudioLM.forward(self, batch_size, text, text_embeds, prime_wave, prime_wave_input_sample_hz, prime_wave_path, max_length, return_coarse_generated_wave, mask_out_generated_fine_tokens)
       2153     prime_wave = prime_wave.to(self.device)
       2155 semantic_token_ids = self.semantic.generate(
       2156     text_embeds = text_embeds if self.semantic_has_condition else None,
       2157     batch_size = batch_size,
       (...)
       2160     max_length = max_length
       2161 )
    -> 2163 coarse_token_ids_or_recon_wave = self.coarse.generate(
       2164     text_embeds = text_embeds if self.coarse_has_condition else None,
       2165     semantic_token_ids = semantic_token_ids,
       2166     prime_wave = prime_wave,
       2167     prime_wave_input_sample_hz = prime_wave_input_sample_hz,
       2168     reconstruct_wave = return_coarse_generated_wave
       2169 )
       2171 if return_coarse_generated_wave:
       2172     return coarse_token_ids_or_recon_wave

    File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/audiolm_pytorch.py:72, in eval_decorator.<locals>.inner(model, *args, **kwargs)
         70 was_training = model.training
         71 model.eval()
    ---> 72 out = fn(model, *args, **kwargs)
         73 model.train(was_training)
         74 return out

    File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
        112 @functools.wraps(func)
        113 def decorate_context(*args, **kwargs):
        114     with ctx_factory():
    --> 115         return func(*args, **kwargs)

    File <@beartype(audiolm_pytorch.audiolm_pytorch.CoarseTransformerWrapper.generate) at 0x7f8c211084c0>:78, in generate(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_140239828827584, __beartype_object_94896400743232, __beartype_getrandbits, *args, **kwargs)

    File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/audiolm_pytorch.py:1668, in CoarseTransformerWrapper.generate(self, semantic_token_ids, prime_wave, prime_wave_input_sample_hz, prime_coarse_token_ids, text, text_embeds, max_time_steps, cond_scale, filter_thres, temperature, reconstruct_wave, use_kv_cache, **kwargs)
       1664     return sampled_coarse_token_ids
       1666 assert exists(self.codec)
    -> 1668 wav = self.codec.decode_from_codebook_indices(sampled_coarse_token_ids)
       1669 return rearrange(wav, 'b 1 n -> b n')

    File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/encodec.py:151, in EncodecWrapper.decode_from_codebook_indices(self, quantized_indices)
        142 assert self.model.sample_rate == 24000,\
        143     "if changing to 48kHz, that model segments its audio into lengths of 1.0 second with 1% overlap, whereas " \
        144     "the 24kHz doesn't segment at all. this means the frame decode logic might change; this is a reminder to " \
        145     "double check that."
        146 # Since 24kHz pretrained doesn't do any segmenting, we have all the frames already (1 frame = 1 token in quantized_indices)
        147 
        148 # The following code is hacked in from self.model.decode() (Encodec version 0.1.1) where we skip the part about
        149 # scaling.
        150 # Shape: 1 x (num_frames * stride product). 1 because we have 1 frame (because no segmenting)
    --> 151 frames = self._decode_frame(quantized_indices)
        152 result = _linear_overlap_add(frames, self.model.segment_stride or 1)
        153 # TODO: I'm not overly pleased with this because when this function gets called, we just rearrange the result
        154 #   back to b n anyways, but we'll keep this as a temporary hack just to make things work for now

    File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/encodec.py:177, in EncodecWrapper._decode_frame(self, quantized_indices)
        175 emb = self.model.quantizer.decode(codes)
        176 # emb shape: batch x self.model.quantizer.dimension x T. Note self.model.quantizer.dimension is the embedding dimension
    --> 177 return self.model.decoder(emb)

    File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
       1496 # If we don't have any hooks, we want to skip the rest of the logic in
       1497 # this function, and just call forward.
       1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1499         or _global_backward_pre_hooks or _global_backward_hooks
       1500         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1501     return forward_call(*args, **kwargs)
       1502 # Do not call functions when jit is used
       1503 full_backward_hooks, non_full_backward_hooks = [], []

    File /usr/local/lib/python3.10/dist-packages/encodec/modules/seanet.py:237, in SEANetDecoder.forward(self, z)
        236 def forward(self, z):
    --> 237     y = self.model(z)
        238     return y

    File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
       1496 # If we don't have any hooks, we want to skip the rest of the logic in
       1497 # this function, and just call forward.
       1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1499         or _global_backward_pre_hooks or _global_backward_hooks
       1500         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1501     return forward_call(*args, **kwargs)
       1502 # Do not call functions when jit is used
       1503 full_backward_hooks, non_full_backward_hooks = [], []

    File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
        215 def forward(self, input):
        216     for module in self:
    --> 217         input = module(input)
        218     return input

    File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
       1496 # If we don't have any hooks, we want to skip the rest of the logic in
       1497 # this function, and just call forward.
       1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1499         or _global_backward_pre_hooks or _global_backward_hooks
       1500         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1501     return forward_call(*args, **kwargs)
       1502 # Do not call functions when jit is used
       1503 full_backward_hooks, non_full_backward_hooks = [], []

    File /usr/local/lib/python3.10/dist-packages/encodec/modules/lstm.py:24, in SLSTM.forward(self, x)
         22 def forward(self, x):
         23     x = x.permute(2, 0, 1)
    ---> 24     y, _ = self.lstm(x)
         25     if self.skip:
         26         y = y + x

    File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
       1496 # If we don't have any hooks, we want to skip the rest of the logic in
       1497 # this function, and just call forward.
       1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1499         or _global_backward_pre_hooks or _global_backward_hooks
       1500         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1501     return forward_call(*args, **kwargs)
       1502 # Do not call functions when jit is used
       1503 full_backward_hooks, non_full_backward_hooks = [], []

    File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/rnn.py:812, in LSTM.forward(self, input, hx)
        810 self.check_forward_args(input, hx, batch_sizes)
        811 if batch_sizes is None:
    --> 812     result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
        813                       self.dropout, self.training, self.bidirectional, self.batch_first)
        814 else:
        815     result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
        816                       self.num_layers, self.dropout, self.training, self.bidirectional)

    RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
lucidrains commented 8 months ago

@rgxb2807 let me try to hack this in in the next half hour

lucidrains commented 8 months ago

@rgxb2807 want to try 1.7.5?

rgxb2807 commented 8 months ago

@lucidrains Looks awesome. I think one tiny bugfix is needed, but with that fix I'm able to generate audio. Very exciting, it's starting to produce speech-sounding output.

Here's a bugfix PR

lucidrains commented 8 months ago

merged, thanks!