facebookresearch / pytorchvideo

A deep learning library for video understanding research.
https://pytorchvideo.org/
Apache License 2.0

RuntimeError: number of dims don't match in permute augmentation #165

Open talhaanwarch opened 2 years ago

talhaanwarch commented 2 years ago

I am trying to apply augmentations from torchvision, but I get the error RuntimeError: number of dims don't match in permute. Here is the code:

train_transform = Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform=Compose(
                  [
                    UniformTemporalSubsample(num_video_samples),
                    Lambda(lambda x: x.permute((0, 2, 3, 4, 1))), # N C T H W to N T H W C -- 5 indices, but the clip here is 4-D, so this fails
                    ColorJitter(brightness=0, contrast=0, saturation=0, hue=0),
                    Lambda(lambda x: x.permute((0, 4, 1, 2, 3))), # N T H W C back to N C T H W
                    Lambda(lambda x: x / 255.0),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(224),

                    RandomHorizontalFlip(p=0.5),
                  ]
                ),
              ),
            ]
        )

Without these permutation lines and ColorJitter, the batch shape is torch.Size([8, 3, 10, 224, 224]). Here is the complete error:
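The mismatch is easy to see in isolation: permute requires exactly one index per tensor dimension, and a batch shape of [8, 3, 10, 224, 224] means each individual clip is 4-D (C, T, H, W), so a 5-index permute cannot apply. A pure-Python sketch of that check (torch deliberately left out so it runs anywhere; permute_shape is an illustrative helper, not a torch API):

```python
def permute_shape(shape, dims):
    """Mimic torch.Tensor.permute's dimension check and axis reordering
    on a plain shape tuple (illustration only, not torch itself)."""
    if len(dims) != len(shape):
        # torch raises: RuntimeError: number of dims don't match in permute
        raise RuntimeError("number of dims don't match in permute")
    return tuple(shape[d] for d in dims)

clip = (3, 10, 224, 224)                   # C, T, H, W -- one clip, no batch dim yet
print(permute_shape(clip, (1, 0, 2, 3)))   # (10, 3, 224, 224): T, C, H, W

try:
    permute_shape(clip, (0, 2, 3, 4, 1))   # five indices against four dims
except RuntimeError as e:
    print(e)                               # number of dims don't match in permute
```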

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_32219/138620282.py in <module>
      7            num_workers=4,
      8            pin_memory=True)
----> 9 batch=next(iter(train_loader))
     10 print(batch['video'].shape)

~/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1083             else:
   1084                 del self._task_info[idx]
-> 1085                 return self._process_data(data)
   1086 
   1087     def _try_put_index(self):

~/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1109         self._try_put_index()
   1110         if isinstance(data, ExceptionWrapper):
-> 1111             data.reraise()
   1112         return data
   1113 

~/venv/lib/python3.8/site-packages/torch/_utils.py in reraise(self)
    426             # have message field
    427             raise self.exc_type(message=msg)
--> 428         raise self.exc_type(msg)
    429 
    430 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/talha/venv/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/talha/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 28, in fetch
    data.append(next(self.dataset_iter))
  File "/media/talha/data/image/video rppg/pyvidcode/pytorchvideo/pytorchvideo/data/labeled_video_dataset.py", line 217, in __next__
    sample_dict = self._transform(sample_dict)
  File "/home/talha/venv/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 67, in __call__
    img = t(img)
  File "/media/talha/data/image/video rppg/pyvidcode/pytorchvideo/pytorchvideo/transforms/transforms.py", line 30, in __call__
    x[self._key] = self._transform(x[self._key])
  File "/home/talha/venv/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 67, in __call__
    img = t(img)
  File "/home/talha/venv/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 393, in __call__
    return self.lambd(img)
  File "/tmp/ipykernel_32219/654316761.py", line 10, in <lambda>
    Lambda(lambda x: x.permute((0, 2, 3, 4, 1))), # N C T H W to N T C H W
RuntimeError: number of dims don't match in permute
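The traceback makes the call chain clear: the dataset applies the transform to each sample dict before any batching, and ApplyTransformToKey forwards only the "video" entry to the wrapped Compose (the line x[self._key] = self._transform(x[self._key]) is visible above). A minimal re-sketch of that wrapper, based solely on what the traceback shows:

```python
class ApplyTransformToKey:
    """Minimal sketch of pytorchvideo's ApplyTransformToKey, matching the
    single line visible in the traceback: the wrapped transform is applied
    to one entry of the sample dict, in place, and the dict is returned."""

    def __init__(self, key, transform):
        self._key = key
        self._transform = transform

    def __call__(self, x):
        x[self._key] = self._transform(x[self._key])
        return x

# Per-sample dict, as produced by the dataset before batching (shapes stand
# in for tensors here so the sketch stays torch-free).
sample = {"video": (3, 10, 224, 224), "label": 1}
double_t = ApplyTransformToKey(
    key="video",
    transform=lambda s: (s[0], 2 * s[1], s[2], s[3]),  # toy transform
)
print(double_t(sample))  # only "video" is transformed; "label" is untouched
```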
talhaanwarch commented 2 years ago

I think this one is correct: the transform runs on a single clip before batching, so the tensor is 4-D (C, T, H, W) and the permutation needs four indices, not five.

train_transform = Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform=Compose(
                  [
                    UniformTemporalSubsample(num_video_samples),
                    Permute([1, 0, 2, 3]),  # [C, T, H, W] -> [T, C, H, W], channels first for ColorJitter
                    ColorJitter(brightness=0.1, contrast=0.1, saturation=0, hue=0),
                    Permute([1, 0, 2, 3]),  # [T, C, H, W] -> [C, T, H, W], the same swap undoes itself
                    Lambda(lambda x: x / 255.0),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(224),

                    RandomHorizontalFlip(p=0.5),
                  ]
                ),
              ),
            ]
        )
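One detail worth noting in the working version: the same Permute([1, 0, 2, 3]) appears both before and after ColorJitter, because swapping the first two axes is its own inverse. A quick check on the index permutations themselves (pure Python; compose is an illustrative helper, not a pytorchvideo API):

```python
def compose(p, q):
    """Compose two axis permutations: permuting by p and then by q.
    Axis i of the final tensor is axis p[q[i]] of the original tensor."""
    return [p[q[i]] for i in range(len(q))]

perm = [1, 0, 2, 3]          # C,T,H,W -> T,C,H,W (swap the first two axes)
print(compose(perm, perm))   # [0, 1, 2, 3]: back to the original axis order
```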