ShivamRajSharma / Vision-Transformer

PyTorch implementation of ViT on CIFAR-10.

Dimension inconsistency #3

Open sptom opened 2 years ago

sptom commented 2 years ago

After running line 134 in the ViT forward:

patches = images.unfold(2, self.patch_height, self.patch_width).unfold(3, self.patch_height, self.patch_width)

I get a tensor with the following sizes: [image]

The next line is: patches = patches.permute(0, 2, 3, 1, 4, 5)

However, the tensor has only 5 dimensions, not 6, so the index '5' is out of range and I get the error: patches = patches.permute(0, 2, 3, 1, 4, 5) RuntimeError: number of dims don't match in permute
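For reference, here is a rough sketch of the shape arithmetic, assuming the documented behaviour of Tensor.unfold(dim, size, step): the unfolded dimension shrinks to the number of windows, and a new trailing dimension of length size is appended. The helper unfold_shape, the patch size and the input sizes are hypothetical and are not taken from the repository. It only computes shapes and does not call PyTorch.

```python
def unfold_shape(shape, dim, size, step):
    """Shape that Tensor.unfold(dim, size, step) would return for an input of `shape`.

    Hypothetical helper for illustration only; it mirrors the documented rule
    that `dim` shrinks to the window count and `size` is appended as a new
    trailing dimension.
    """
    windows = (shape[dim] - size) // step + 1
    return shape[:dim] + (windows,) + shape[dim + 1:] + (size,)


patch = 4  # assumed patch height and width (square patches)

batched = (8, 3, 32, 32)  # assumed (batch, channels, height, width)
single = (3, 32, 32)      # assumed (channels, height, width), no batch dim

for name, shape in (("batched", batched), ("single", single)):
    # Same two calls as in the forward pass: unfold dim 2, then dim 3.
    out = unfold_shape(unfold_shape(shape, 2, patch, patch), 3, patch, patch)
    print(name, shape, "->", out, f"({len(out)} dims)")
```

Under these assumptions a 4-D batch ends up with 6 dimensions, which is what permute(0, 2, 3, 1, 4, 5) expects, while a 3-D input ends up with only 5. So one possible cause is that images reaches forward without a batch (or channel) dimension, though I have not confirmed that this is what happens here.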

Please help. It seems like the files are out of sync with each other. Either the ImageTransformer file found here is an older version (most likely), or the train.py file is inconsistent with it. If so, please provide the updated version, or explain here how to fix the dimension issue.

Thanks a lot!