lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

Error when I resume training the SoundStream #119

Closed So-Fann closed 1 year ago

So-Fann commented 1 year ago

Traceback (most recent call last):

  /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/soundstream_train.py:17 in <module>
      trainer.train()

  /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/trainer.py:424 in train
      logs = self.train_step()

  /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/trainer.py:314 in train_step
      loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, r

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
      return forward_call(*input, **kwargs)

  /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/soundstream.py:546 in forward
      x = self.encoder_attn(x)

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
      return forward_call(*input, **kwargs)

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/container.py:139 in forward
      input = module(input)

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
      return forward_call(*input, **kwargs)

  /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/soundstream.py:337 in forward
      x = self.attn(x) + x

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
      return forward_call(*input, **kwargs)

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/local_attention/transformer.py:94 in forward
      out = self.attn_fn(q, k, v, mask = mask, attn_bias = attn_bias)

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
      return forward_call(*input, **kwargs)

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/local_attention/local_attention.py:154 in forward
      bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)

  /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/local_attention/rotary.py:58 in apply_rotary_pos_emb
      scale = scale[-q_len:, :]

IndexError: too many indices for tensor of dimension 1
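For context, the failing line in local_attention/rotary.py indexes the xpos scale tensor with two dimensions' worth of indices, but the scale it receives is 1-D. A minimal sketch of that indexing pattern (the shapes here are illustrative, not taken from local-attention):

    import torch

    # apply_rotary_pos_emb does `scale = scale[-q_len:, :]`, which assumes `scale`
    # has at least two dimensions. If the scale passed in is 1-D, the second index
    # is one too many and PyTorch raises the same error as in the traceback above.
    scale = torch.ones(64)       # hypothetical 1-D scale tensor
    q_len = 32
    scale = scale[-q_len:, :]    # IndexError: too many indices for tensor of dimension 1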

So-Fann commented 1 year ago

Why do I get this error: "too many indices for tensor of dimension 1"? I was able to train normally a few days ago.

So-Fann commented 1 year ago

I found out where the problem is: the local-attention package was updated recently. Rolling back to version 1.6.0 lets training run normally again.
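For anyone hitting the same thing, pinning the dependency to that version would look like:

    pip install local-attention==1.6.0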

lucidrains commented 1 year ago

@So-Fann ah yes, that is a newly introduced bug, should be ok in local-attention 1.8.2!
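In other words, once that release is published, upgrading the package (rather than pinning it) should also resolve the error, e.g.:

    pip install -U local-attention    # assumes 1.8.2 or later is available on PyPI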

So-Fann commented 1 year ago

> @So-Fann ah yes, that is a newly introduced bug, should be ok in local-attention 1.8.2!

thank you!