mhamilton723 / STEGO

Unsupervised Semantic Segmentation by Distilling Feature Correspondences
MIT License

tensor size mismatch error #93

Open AkankshaP0102 opened 3 months ago

AkankshaP0102 commented 3 months ago

I am trying to train STEGO on a custom dataset, but during the training process, if I provide labels for the corresponding images, I get the following error:

Traceback (most recent call last):
  File "train_segmentation.py", line 598, in my_app
    trainer.fit(model, train_loader, val_loader)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
    self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
    self._dispatch()
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1311, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1375, in _run_sanity_check
    self._evaluation_loop.run()
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
    output = self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 239, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "train_segmentation.py", line 354, in validation_step
    self.linear_metrics.update(linear_preds, label)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/torchmetrics/metric.py", line 405, in wrapped_func
    raise err
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/torchmetrics/metric.py", line 395, in wrapped_func
    update(*args, **kwargs)
  File "/media/2d46715b-293d-4478-acd4-5f000d443896/stego-studies/src/utils.py", line 240, in update
    mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
RuntimeError: The size of tensor a (3276800) must match the size of tensor b (10240) at non-singleton dimension 0

Please help me with this error. Thank you.

Ruhrozz commented 2 months ago

I have found the solution.

Most likely you have the same issue I had. In DirectoryDataset, after Image.open("mask"), the mask has shape [H, W, 3]. The validation step then interpolates the predictions to labels[-2:], so the prediction shape ends up as something like [B, C, H, 3].
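The shape bookkeeping described above can be sketched in plain Python, without running the model. This is only an illustration: the helper names and the sizes B, C, H, W are hypothetical, and the sketch assumes the predictions are resized to the last two dimensions of the label batch and then reduced over the class dimension.

```python
from math import prod


def interpolated_pred_shape(pred_shape, label_shape):
    """Shape of predictions after their spatial dims are resized to label_shape[-2:]."""
    return tuple(pred_shape[:2]) + tuple(label_shape[-2:])


def class_map_shape(pred_shape):
    """Shape after taking an argmax over the class dimension (dim 1)."""
    return (pred_shape[0],) + tuple(pred_shape[2:])


# Hypothetical sizes, chosen only to keep the numbers small.
B, C, H, W = 2, 5, 4, 4
raw_preds = (B, C, 2, 2)      # low-resolution class scores

rgb_label = (B, H, W, 3)      # batch of masks loaded as RGB
gray_label = (B, H, W)        # batch of masks loaded as single-channel

rgb_preds = class_map_shape(interpolated_pred_shape(raw_preds, rgb_label))
gray_preds = class_map_shape(interpolated_pred_shape(raw_preds, gray_label))

# With RGB masks the last two label dims are (W, 3), so element counts differ.
print(rgb_preds, prod(rgb_preds), "vs label", prod(rgb_label))
# With single-channel masks the shapes line up.
print(gray_preds, prod(gray_preds), "vs label", prod(gray_label))
```

When the element counts differ, any elementwise comparison between flattened predictions and labels fails with a size mismatch of the kind shown in the traceback.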

I solved the problem by converting the mask to grayscale, so that the mask shape is just [H, W]:

label = Image.open(join(self.label_dir, label_fn)).convert('L')

See this and this for additional information.
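A small check of what the conversion does to the array shape may help when debugging a custom dataset. The sketch below uses an in-memory image as a hypothetical stand-in for a mask file; it assumes Pillow and NumPy are installed and does not touch the STEGO code itself.

```python
import numpy as np
from PIL import Image

# Hypothetical stand-in for a label file that was saved with three channels.
# Every pixel carries class id 2 in all three channels.
rgb_mask = Image.fromarray(np.full((4, 6, 3), 2, dtype=np.uint8))

# Same conversion as in the comment above.
gray_mask = rgb_mask.convert("L")

rgb_shape = np.array(rgb_mask).shape    # (H, W, 3)
gray_shape = np.array(gray_mask).shape  # (H, W)

print("before:", rgb_shape)
print("after: ", gray_shape)
print("values:", np.unique(np.array(gray_mask)))
```

One caveat: convert('L') computes a weighted sum of the R, G, and B channels. Class ids are preserved when all three channels hold the same value, as in this example, but a color-coded mask would need an explicit color-to-class mapping instead.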

AkankshaP0102 commented 1 month ago

Thank you @Ruhrozz. I tried the solution you suggested. It fixed the tensor size error, but I am now getting an error about the label image shape. Did you run into this error as well? Your help would be appreciated. I have pasted the error below:

Traceback (most recent call last):
  File "train_segmentation.py", line 503, in my_app
    trainer.fit(model, train_loader, val_loader)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
    self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
    self._dispatch()
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1311, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1375, in _run_sanity_check
    self._evaluation_loop.run()
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 151, in run
    output = self.on_run_end()
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 131, in on_run_end
    self._evaluation_epoch_end(outputs)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 236, in _evaluation_epoch_end
    model.validation_epoch_end(outputs)
  File "train_segmentation.py", line 297, in validation_epoch_end
    ax[1, i].imshow(self.label_cmap[output["label"][i]])
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/_api/deprecation.py", line 459, in wrapper
    return func(*args, **kwargs)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/__init__.py", line 1414, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/axes/_axes.py", line 5487, in imshow
    im.set_data(X)
  File "/media/emsg/2d46715b-293d-4478-acd4-5f000d443896/anaconda3/envs/stegostudies/lib/python3.7/site-packages/matplotlib/image.py", line 716, in set_data
    .format(self._A.shape))
TypeError: Invalid shape (1, 320, 320, 3) for image data
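One possible reading of this traceback, offered as a sketch rather than a confirmed diagnosis: the failing line indexes a colormap with a per-image label, and indexing an [n_classes, 3] array with an integer array of shape [1, H, W] returns [1, H, W, 3], which imshow rejects. The NumPy example below reproduces that shape. The colormap contents, n_classes, and the leading singleton dimension are assumptions, and the arrays in the thread may be torch tensors rather than NumPy arrays.

```python
import numpy as np

# Hypothetical colormap: one RGB color per class, shape [n_classes, 3].
n_classes = 4
label_cmap = np.arange(n_classes * 3, dtype=np.uint8).reshape(n_classes, 3)

# Assumed per-image label with a leading singleton dimension, shape [1, H, W].
label = np.zeros((1, 320, 320), dtype=np.int64)

# Integer-array indexing appends the color axis to the index shape.
colored = label_cmap[label]
print(colored.shape)      # (1, 320, 320, 3)

# Dropping the singleton dimension first gives an image-shaped array.
colored_2d = label_cmap[label.squeeze(0)]
print(colored_2d.shape)   # (320, 320, 3)
```

If this is what is happening, removing the extra dimension from the label before the colormap lookup should give imshow an [H, W, 3] array. Where that extra dimension comes from in the custom dataset would still need to be checked.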