0nutation / SpeechGPT

SpeechGPT Series: Speech Large Language Models
https://0nutation.github.io/SpeechGPT.github.io/
Apache License 2.0

stage 2: dimension mismatch #43

Open · ehosseiniasl opened this issue 2 months ago

ehosseiniasl commented 2 months ago

Hello,

I encountered this issue when running cm_sft.py:

0:   File "SpeechGPT/speechgpt/src/train/cm_sft_modified.py", line 460, in <module>
0:     train()
0:   File "SpeechGPT/speechgpt/src/train/cm_sft_modified.py", line 428, in train
0:     train_result = trainer.train(resume_from_checkpoint=checkpoint)
0:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1539, in train
0:     return inner_training_loop(
0:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1836, in _inner_training_loop
0:     for step, inputs in enumerate(epoch_iterator):
0:   File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 560, in __iter__
0:     next_batch, next_batch_info = self._fetch_batches(main_iterator)
0:   File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 523, in _fetch_batches
0:     batches.append(next(iterator))
0:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
0:     data = self._next_data()
0:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 674, in _next_data
0:     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
0:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
0:     data.append(next(self.dataset_iter))
0:   File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 1384, in __iter__
0:     for key, example in ex_iterable:
0:   File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 679, in __iter__
0:     yield from self._iter()
0:   File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 694, in _iter
0:     for key, example in iterator:
0:   File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 679, in __iter__
0:     yield from self._iter()
0:   File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 731, in _iter
0:     raise ValueError(
0: ValueError: Column lengths mismatch: columns ['input_ids', 'attention_mask'] have length [512, 512] while prefix has length 1000.
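For reference, this ValueError is the consistency check that datasets applies to batched map outputs: every column in the returned batch must have the same number of rows, and here the raw "prefix" column is carried along with a different length than the tokenized columns. A minimal sketch of one way to avoid it, assuming the data is tokenized in a batched .map over a streaming dataset (the file name, tokenizer path, and column names below are placeholders, not the repo's actual pipeline):

```python
# Minimal sketch (placeholder paths/names): tokenize in a batched map and drop
# the raw "prefix" column so it cannot disagree in length with the map's output.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")  # placeholder

def tokenize_batch(batch):
    enc = tokenizer(batch["prefix"], truncation=True, max_length=512)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}

raw = load_dataset("json", data_files="train.jsonl", split="train", streaming=True)
# remove_columns drops "prefix" from the output, so the length check only
# compares the columns actually returned by tokenize_batch (which always agree),
# even if grouping/packing changes the number of rows per batch.
tokenized = raw.map(tokenize_batch, batched=True, remove_columns=["prefix"])
```

As the follow-up below shows, dropping the column alone only moves the problem: the examples still end up with unequal sequence lengths, which surfaces at batch concatenation time.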
ehosseiniasl commented 2 months ago

I removed the "prefix" column name, but then I get this error:

0:   File "SpeechGPT/speechgpt/src/train/cm_sft.py", line 394, in <module>
0:     train()
0:   File "SpeechGPT/speechgpt/src/train/cm_sft.py", line 362, in train
0:     train_result = trainer.train(resume_from_checkpoint=checkpoint)
0:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1539, in train
0:     return inner_training_loop(
0:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1836, in _inner_training_loop
0:     for step, inputs in enumerate(epoch_iterator):
0:   File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 560, in __iter__
0:     next_batch, next_batch_info = self._fetch_batches(main_iterator)
0:   File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 524, in _fetch_batches
0:     batch = concatenate(batches, dim=0)
0:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 496, in concatenate
0:     return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
0:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 496, in <dictcomp>
0:     return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
0:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 499, in concatenate
0:     return torch.cat(data, dim=dim)
0: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 480 but got size 512 for tensor number 1 in the list.
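This RuntimeError comes from accelerate concatenating the per-process sub-batches along dim 0, which only works if every sequence in the group has the same length (here 480 vs. 512). A minimal sketch of a fixed-length collator, assuming the mismatch is caused by examples padded or truncated to different lengths (PAD_ID, MAX_LEN, and the function name are illustrative, not from the repo):

```python
# Minimal sketch (names/values are assumptions, not from the repo): pad every
# feature to one fixed length so batches from different processes can be
# concatenated along dim 0 by accelerate.
import torch
import torch.nn.functional as F

PAD_ID = 0      # assumption: use tokenizer.pad_token_id
MAX_LEN = 512   # assumption: the run's max sequence length

def collate_fixed_length(features):
    def to_fixed(key, pad_value):
        rows = []
        for f in features:
            t = torch.as_tensor(f[key])[:MAX_LEN]           # truncate long rows
            rows.append(F.pad(t, (0, MAX_LEN - t.shape[0]), value=pad_value))
        return torch.stack(rows)
    return {
        "input_ids": to_fixed("input_ids", PAD_ID),
        "attention_mask": to_fixed("attention_mask", 0),
        "labels": to_fixed("labels", -100),  # assumption: -100 = ignore_index
    }

# e.g. Trainer(..., data_collator=collate_fixed_length)
```

Padding to a fixed length in the tokenizer itself (padding="max_length") would have the same effect; either way, every tensor in a batch ends up with the same sequence dimension.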