Closed So-Fann closed 1 year ago
Why do I get this error? `IndexError: too many indices for tensor of dimension 1`. I was able to train normally a few days ago.
I found the problem: the `local-attention` package was updated recently. After rolling back to version 1.6.0, training works normally again.
@So-Fann Ah yes, that is a newly introduced bug; it should be fixed in local-attention 1.8.2!
@So-Fann Ah yes, that is a newly introduced bug; it should be fixed in local-attention 1.8.2!
thank you!
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮ │ /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/soundstream_train.py:17 in │
│ │
│ 14 │ num_train_steps = 100000 │
│ 15 ).cuda() │
│ 16 #trainer.load('results/soundstream.15000.pt') │
│ ❱ 17 trainer.train() │
│ 18 │
│ │
│ /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/trainer.py:424 in │
│ train │
│ │
│ 421 │ def train(self, log_fn = noop): │
│ 422 │ │ │
│ 423 │ │ while self.steps < self.num_train_steps: │
│ ❱ 424 │ │ │ logs = self.train_step() │
│ 425 │ │ │ log_fn(logs) │
│ 426 │ │ │
│ 427 │ │ self.print('training complete') │
│ │
│ /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/trainer.py:314 in │
│ train_step │
│ │
│ 311 │ │ │ wave, = next(self.dl_iter) │
│ 312 │ │ │ wave = wave.to(device) │
│ 313 │ │ │ │
│ ❱ 314 │ │ │ loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, r │
│ 315 │ │ │ │
│ 316 │ │ │ self.accelerator.backward(loss / self.grad_accum_every) │
│ 317 │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 │
│ in _call_impl │
│ │
│ 1127 │ │ # this function, and just call forward. │
│ 1128 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1129 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1130 │ │ │ return forward_call(*input, **kwargs) │
│ 1131 │ │ # Do not call functions when jit is used │
│ 1132 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1133 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/soundstream.py:546 in │
│ forward │
│ │
│ 543 │ │ x = rearrange(x, 'b c n -> b n c') │
│ 544 │ │ │
│ 545 │ │ if exists(self.encoder_attn): │
│ ❱ 546 │ │ │ x = self.encoder_attn(x) │
│ 547 │ │ │
│ 548 │ │ x, indices, commit_loss = self.rq(x) │
│ 549 │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 │
│ in _call_impl │
│ │
│ 1127 │ │ # this function, and just call forward. │
│ 1128 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1129 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1130 │ │ │ return forward_call(*input, **kwargs) │
│ 1131 │ │ # Do not call functions when jit is used │
│ 1132 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1133 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/container.py:139 │
│ in forward │
│ │
│ 136 │ # with Any as TorchScript expects a more precise type │
│ 137 │ def forward(self, input): │
│ 138 │ │ for module in self: │
│ ❱ 139 │ │ │ input = module(input) │
│ 140 │ │ return input │
│ 141 │ │
│ 142 │ def append(self, module: Module) -> 'Sequential': │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 │
│ in _call_impl │
│ │
│ 1127 │ │ # this function, and just call forward. │
│ 1128 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1129 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1130 │ │ │ return forward_call(*input, **kwargs) │
│ 1131 │ │ # Do not call functions when jit is used │
│ 1132 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1133 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /group/30042/sofanzhou/code/audiolm/audiolm-pytorch-0.12.5/audiolm_pytorch/soundstream.py:337 in │
│ forward │
│ │
│ 334 │ │ self.ff = FeedForward(dim = dim) │
│ 335 │ │
│ 336 │ def forward(self, x): │
│ ❱ 337 │ │ x = self.attn(x) + x │
│ 338 │ │ x = self.ff(x) + x │
│ 339 │ │ return x │
│ 340 │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 │
│ in _call_impl │
│ │
│ 1127 │ │ # this function, and just call forward. │
│ 1128 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1129 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1130 │ │ │ return forward_call(*input, **kwargs) │
│ 1131 │ │ # Do not call functions when jit is used │
│ 1132 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1133 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/local_attention/transformer.py:94 │
│ in forward │
│ │
│ 91 │ │ │ q = q * self.q_scale │
│ 92 │ │ │ k = k * self.k_scale │
│ 93 │ │ │
│ ❱ 94 │ │ out = self.attn_fn(q, k, v, mask = mask, attn_bias = attn_bias) │
│ 95 │ │ │
│ 96 │ │ out = rearrange(out, 'b h n d -> b n (h d)') │
│ 97 │ │ return self.to_out(out) │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 │
│ in _call_impl │
│ │
│ 1127 │ │ # this function, and just call forward. │
│ 1128 │ │ if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o │
│ 1129 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1130 │ │ │ return forward_call(*input, **kwargs) │
│ 1131 │ │ # Do not call functions when jit is used │
│ 1132 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1133 │ │ if self._backward_hooks or _global_backward_hooks: │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/local_attention/local_attention.p │
│ y:154 in forward │
│ │
│ 151 │ │ │
│ 152 │ │ if exists(self.rel_pos): │
│ 153 │ │ │ pos_emb, xpos_scale = self.rel_pos(bk) │
│ ❱ 154 │ │ │ bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale) │
│ 155 │ │ │
│ 156 │ │ # calculate positions for masking │
│ 157 │
│ │
│ /data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/local_attention/rotary.py:58 in │
│ apply_rotary_pos_emb │
│ │
│ 55 │ inv_scale = scale ** -1 │
│ 56 │ │
│ 57 │ if isinstance(scale, torch.Tensor): │
│ ❱ 58 │ │ scale = scale[-q_len:, :] │
│ 59 │ │
│ 60 │ q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale) │
│ 61 │ k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: too many indices for tensor of dimension 1