LukeForeverYoung / UReader

Apache License 2.0

Training issue due to varying patch_position dimensions #8

Closed by ambroser53 8 months ago

ambroser53 commented 8 months ago

I'm trying to use your training script but can't get past the first iteration of the training loop. The collation function used by the trainer stacks every element of the batch, but the patch positions it receives vary in size (most are [13, 2] but some are [16, 2]), so the stack fails. Here is the full error:

Traceback (most recent call last):
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/accelerate/data_loader.py", line 384, in __iter__
    current_batch = next(dataloader_iter)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/transformers/trainer_utils.py", line 737, in __call__
    return self.data_collator(features)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/transformers/data/data_collator.py", line 70, in default_data_collator
    return torch_default_data_collator(features)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ambrose/anaconda3/envs/monster/lib/python3.11/site-packages/transformers/data/data_collator.py", line 132, in torch_default_data_collator
    batch[k] = torch.stack([f[k] for f in features])
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [16, 2] at entry 0 and [13, 2] at entry 1

Is this a known issue? I am training on custom data but have not changed any of the original code. What can I change to make this work? Will I simply have to run with a batch size of 1?

LukeForeverYoung commented 8 months ago

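Our training pipeline calls batchify to collate each batch of data (see CustomTrainer). It looks like your training pipeline is falling back to the default data_collator of the Hugging Face Trainer, which stacks tensors with torch.stack and therefore requires every sample to have the same shape.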
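
For readers who hit the same error outside this repo, here is a minimal sketch of a collate function that tolerates a varying number of patches per sample. It is not UReader's batchify. The function name, the "patch_positions" and "labels" keys, and the extra "num_patches" field are assumptions made for illustration. The idea is to concatenate the variable-length tensors along the first dimension, record how many rows each sample contributed, and stack everything else as usual.

```python
import torch


def collate_variable_patches(features):
    """Hypothetical collate function (not UReader's batchify).

    Assumes every sample is a dict of tensors, and that only the
    "patch_positions" entry (shape [num_patches, 2]) varies in size.
    """
    batch = {}
    for key in features[0]:
        tensors = [f[key] for f in features]
        if key == "patch_positions":
            # Concatenate instead of stacking, since num_patches differs.
            batch[key] = torch.cat(tensors, dim=0)
            # Keep per-sample counts so the rows can be split apart later.
            batch["num_patches"] = torch.tensor([t.shape[0] for t in tensors])
        else:
            # Same-shaped tensors can still be stacked safely.
            batch[key] = torch.stack(tensors)
    return batch


# Two samples with the shapes reported in the traceback above.
features = [
    {"patch_positions": torch.zeros(16, 2, dtype=torch.long), "labels": torch.tensor(0)},
    {"patch_positions": torch.zeros(13, 2, dtype=torch.long), "labels": torch.tensor(1)},
]
batch = collate_variable_patches(features)
print(batch["patch_positions"].shape)  # torch.Size([29, 2])
```

A function like this can be passed to the Hugging Face Trainer through its data_collator argument. The model then has to accept the concatenated layout, so for this repo the simpler route is to train through the provided CustomTrainer.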

ambroser53 commented 8 months ago

This fixed the issue, thank you very much.