codeslake / Color_Transfer_Histogram_Analogy

[CGI 2020] Official PyTorch Implementation for "Deep Color Transfer using Histogram Analogy"
GNU Affero General Public License v3.0

RuntimeError: expected backend CPU and dtype Double but got backend CPU and dtype Float #6

Closed: tachikoma777 closed this issue 3 years ago

tachikoma777 commented 3 years ago

(Pytorch-1.0.0) sh-4.3$ python test.py --dataroot test --checkpoints_dir checkpoint --is_SR
/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/models/networks.py:17: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  init.normal(m.weight.data, 0.0, 0.02)
/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/models/networks.py:19: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  init.normal(m.weight.data, 0.0, 0.02)
Traceback (most recent call last):
  File "test.py", line 24, in <module>
    for i, data in enumerate(dataset):
  File "/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/data/custom_dataset_data_loader.py", line 33, in __iter__
    for i, data in enumerate(self.dataloader):
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
    return self._process_next_batch(batch)
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/data/aligned_dataset_rand_seg_onlymap.py", line 48, in __getitem__
    A = self.transform_type(Aimg)
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 164, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 208, in normalize
    tensor.sub(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: expected backend CPU and dtype Double but got backend CPU and dtype Float

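For anyone landing here with the first error: the message suggests that the image tensor reaching Normalize is float64 (Double) while Normalize's mean/std are float32 (Float), and PyTorch 1.0 raises instead of mixing the two. A common way to end up with a Double tensor is to build the image array with NumPy, because dividing a uint8 array by 255.0 produces float64 and torch.from_numpy keeps that dtype. This thread does not confirm that this is what transform_type does in this repo, so treat the following as a sketch under that assumption; pil_to_float32_chw is a hypothetical helper, not part of the code base.

```python
import numpy as np
from PIL import Image


def pil_to_float32_chw(img):
    """Hypothetical helper: PIL image -> float32 array in [0, 1], layout (C, H, W).

    np.asarray(img) / 255.0 would give float64, which torch.from_numpy turns
    into a Double tensor. Casting to float32 up front keeps the result Float,
    matching the Float mean/std reported in the error.
    """
    arr = np.asarray(img.convert("RGB")).astype(np.float32)  # (H, W, 3), float32
    arr /= 255.0                                              # in place, stays float32
    return np.ascontiguousarray(arr.transpose(2, 0, 1))      # (H, W, C) -> (C, H, W)


if __name__ == "__main__":
    demo = Image.new("RGB", (5, 3), (255, 0, 128))  # width 5, height 3
    chw = pil_to_float32_chw(demo)
    print(chw.dtype, chw.shape)  # float32 (3, 3, 5)
```

If this is the cause, torch.from_numpy on the returned array gives a torch.float32 tensor; calling .float() on an existing tensor before Normalize has the same effect. Matching the tested environment, as suggested later in the thread, may be the simpler route.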
tachikoma777 commented 3 years ago

Thanks for the great work! An error occurred when I ran test.py; the output is shown above. Does anyone know why?

And after I switched to torch 1.8, a different error occurred:

python test.py --dataroot test --checkpoints_dir checkpoint --is_SR
/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/models/networks.py:17: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  init.normal(m.weight.data, 0.0, 0.02)
/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/models/networks.py:19: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  init.normal(m.weight.data, 0.0, 0.02)
Traceback (most recent call last):
  File "test.py", line 24, in <module>
    for i, data in enumerate(dataset):
  File "/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/data/custom_dataset_data_loader.py", line 33, in __iter__
    for i, data in enumerate(self.dataloader):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1199, in _next_data
    return self._process_data(data)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1225, in _process_data
    data.reraise()
  File "/opt/conda/lib/python3.7/site-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ma-user/work/UD_ISP/color_transfer/Color_Transfer_Histogram_Analogy-master/data/aligned_dataset_rand_seg_onlymap.py", line 56, in __getitem__
    B_map=self.transform_no(Image.open(B_path_map))
  File "/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
    img = t(img)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 221, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/opt/conda/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 336, in normalize
    tensor.sub(mean).div_(std)
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

codeslake commented 3 years ago

Did you try it in the same environment we tested with?

tachikoma777 commented 3 years ago

It seems to be an environment problem.

tiagomfmadeira commented 9 months ago

If anyone else finds this issue and hits "RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0": it means the map images have 4 channels (RGBA) instead of 3 (RGB), so the 3-value mean of the Normalize transform cannot be applied to them.
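A quick way to see the mismatch without PyTorch: a map saved with an alpha channel opens in Pillow's RGBA mode, so its channel-first array has 4 channels, while a Normalize mean built for RGB has 3 entries. The 0.5 mean values below are placeholders for illustration, not the repo's actual settings; NumPy refuses to broadcast for the same reason the torch subtraction fails.

```python
import numpy as np
from PIL import Image

# A map saved with transparency opens in RGBA mode: four bands instead of three.
rgba_map = Image.new("RGBA", (8, 6), (255, 0, 0, 255))  # width 8, height 6
rgb_map = rgba_map.convert("RGB")  # drops the alpha band

print(rgba_map.mode, len(rgba_map.getbands()))  # RGBA 4
print(rgb_map.mode, len(rgb_map.getbands()))    # RGB 3

# Channel-first (C, H, W) layout, like the tensor that reaches Normalize.
chw_rgba = np.asarray(rgba_map).astype(np.float32).transpose(2, 0, 1)  # (4, 6, 8)
chw_rgb = np.asarray(rgb_map).astype(np.float32).transpose(2, 0, 1)    # (3, 6, 8)

# Placeholder per-channel mean for an RGB Normalize: exactly 3 entries.
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)[:, None, None]  # (3, 1, 1)

print((chw_rgb - mean).shape)  # (3, 6, 8): broadcasts fine
try:
    chw_rgba - mean  # 4 vs 3 on dim 0: cannot broadcast
except ValueError as err:
    print("RGBA map fails:", err)
```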

Here's what to do:

In data/aligned_dataset_rand_seg_onlymap.py, lines 55 and 56 are:

A_map=self.transform_no(Image.open(A_path_map))
B_map=self.transform_no(Image.open(B_path_map))

Change them to:

A_map=self.transform_no(Image.open(A_path_map).convert('RGB'))
B_map=self.transform_no(Image.open(B_path_map).convert('RGB'))
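If you would rather check the data than (or as well as) patch the loader, a small script can list which map files are not plain RGB. This is a sketch: find_non_rgb_maps is a hypothetical helper, not part of the repo, and the extension list is an assumption about how the maps are stored. Note that Pillow's convert('RGB') simply discards the alpha band rather than compositing it onto a background, and it also expands single-channel ('L') or palette ('P') images to three channels, so the change above handles those modes as well.

```python
import os
from PIL import Image

# Assumed set of image extensions for the map files; adjust to your data.
MAP_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp"}


def find_non_rgb_maps(map_dir):
    """Hypothetical helper: return (filename, mode) for every map image in
    map_dir whose Pillow mode is not plain "RGB" (e.g. "RGBA", "L", "P")."""
    offenders = []
    for name in sorted(os.listdir(map_dir)):
        if os.path.splitext(name)[1].lower() not in MAP_EXTENSIONS:
            continue  # skip files that are not images
        # Opening only reads the header, which is enough to know the mode.
        with Image.open(os.path.join(map_dir, name)) as img:
            if img.mode != "RGB":
                offenders.append((name, img.mode))
    return offenders
```

For example, find_non_rgb_maps("path/to/maps") (a placeholder path; the thread does not name the map folder) returns an empty list once every map is RGB.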

Maybe @codeslake can add this change, or accept a PR.

Cheers,