tlpss / keypoint-detection

2D keypoint detection with Pytorch Lightning and wandb
MIT License

Training error when image resolution is not a multiple of 32 #25

Open jarnebogaert opened 1 year ago

jarnebogaert commented 1 year ago

Output of `train.py`:

```
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/jbogaert/keypoint-detection/keypoint_detection/train/train.py", line 115, in <module>
    main(hparams)
  File "/home/jbogaert/keypoint-detection/keypoint_detection/train/train.py", line 60, in main
    trainer.fit(model, data_module)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
    results = self._run_stage()
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
    self._run_train()
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run_train
    self._run_sanity_check()
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1267, in _run_sanity_check
    val_loop.run()
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "/home/jbogaert/keypoint-detection/keypoint_detection/models/detector.py", line 319, in validation_step
    result_dict = self.shared_step(val_batch, batch_idx, include_visualization_data_in_result_dict=True)
  File "/home/jbogaert/keypoint-detection/keypoint_detection/models/detector.py", line 220, in shared_step
    predicted_unnormalized_maps = self.forward_unnormalized(input_images)
  File "/home/jbogaert/keypoint-detection/keypoint_detection/models/detector.py", line 169, in forward_unnormalized
    return self.unnormalized_model(x)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jbogaert/keypoint-detection/keypoint_detection/models/backbones/unet.py", line 102, in forward
    x = block(x, x_skip)
  File "/home/jbogaert/miniconda3/envs/keypoint-detection/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jbogaert/keypoint-detection/keypoint_detection/models/backbones/unet.py", line 59, in forward
    x = torch.cat([x, x_skip], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 620 but got size 621 for tensor number 1 in the list.
```

The input image size was 2208x1242. The error was resolved once the images were rescaled so that both dimensions are a multiple of 32.
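For anyone hitting the same thing, a quick offline workaround is to resize the dataset so every side becomes a multiple of 32. Below is a minimal sketch, assuming a hypothetical folder of PNG images (the paths and helper are not part of this repo); the keypoint annotations would have to be scaled by the same factors.

```python
# Offline workaround sketch: snap every image to the nearest multiple of 32
# before building the dataset. Remember to rescale keypoint coordinates too.
import glob
import cv2

def nearest_multiple(value: int, base: int = 32) -> int:
    """Round `value` to the closest multiple of `base` (at least `base`)."""
    return max(base, int(round(value / base)) * base)

for path in glob.glob("dataset/images/*.png"):   # hypothetical dataset layout
    image = cv2.imread(path)
    if image is None:
        continue
    height, width = image.shape[:2]
    new_size = (nearest_multiple(width), nearest_multiple(height))
    if new_size != (width, height):
        resized = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
        cv2.imwrite(path, resized)
```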

lucasvandijck commented 1 year ago

This is a known issue; most models are sensitive to this. You can try my PR, which allows you to dynamically rescale during training without having to change the data itself.
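For readers landing here: one generic way to make a UNet-style model tolerate arbitrary resolutions at runtime (an illustration only, not necessarily how the PR implements it) is to pad each batch up to the next multiple of 32 before the backbone and crop the output back afterwards.

```python
# Illustration only: pad NCHW input up to the next multiple of 32, run the
# model, then crop the prediction back to the original resolution.
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, base: int = 32):
    """Zero-pad H and W of an NCHW tensor up to the next multiple of `base`."""
    h, w = x.shape[-2:]
    pad_h = (base - h % base) % base
    pad_w = (base - w % base) % base
    # F.pad order for the last two dims: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)

def crop_to_original(x: torch.Tensor, size):
    h, w = size
    return x[..., :h, :w]

# usage sketch around a stand-in model
images = torch.randn(2, 3, 1242, 2208)                 # resolution from the issue
padded, original_size = pad_to_multiple(images)        # -> (2, 3, 1248, 2208)
heatmaps = torch.nn.Identity()(padded)                 # stand-in for the keypoint detector
heatmaps = crop_to_original(heatmaps, original_size)   # back to 1242 x 2208
```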