richzhang / colorization-pytorch

PyTorch reimplementation of Interactive Deep Colorization
https://richzhang.github.io/ideepcolor/
MIT License

I found the error in util.py #22

niceDuckgu closed this issue 2 years ago

niceDuckgu commented 2 years ago

The average patch value is not computed correctly at line 248 of util.py, in add_color_patches_rand_gt(). The current calculation is

torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)

However, I believe it should be changed to

torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=-1,keepdim=True).view(1,C,1,1)

The original code takes the mean across the ab channels, while the proposed code takes it over the pixel positions within the patch.
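Which axes `dim=1`, `dim=2` and `dim=-1` refer to depends on whether the sliced tensor still has its batch axis. Here is a minimal sketch that traces the shapes for both cases. The tensor `B` and the values of `N`, `C`, `H`, `W`, `P`, `h`, `w` and `nn` are made-up stand-ins for the issue's variables, not values taken from the repository.

```python
import torch

# Hypothetical stand-ins: one image (N=1) with C=2 ab channels, 3x3 patch at (h, w).
N, C, H, W = 1, 2, 8, 8
P, h, w, nn = 3, 2, 4, 0
B = torch.arange(N * C * H * W, dtype=torch.float32).view(N, C, H, W)

# An integer index drops the batch axis: the slice has shape (C, P, P).
patch3d = B[nn, :, h:h+P, w:w+P]
# A slice index keeps the batch axis: the slice has shape (1, C, P, P).
patch4d = B[nn:nn+1, :, h:h+P, w:w+P]

# On the 3-D slice, dim=2 and dim=1 are the two spatial axes.
m3 = torch.mean(torch.mean(patch3d, dim=2, keepdim=True), dim=1, keepdim=True)

# On the 4-D slice, dim=1 is the channel axis, so the same call mixes a and b.
m4 = torch.mean(torch.mean(patch4d, dim=2, keepdim=True), dim=1, keepdim=True)

# On the 4-D slice, dim=2 followed by dim=-1 reduces the two spatial axes.
m4_fixed = torch.mean(torch.mean(patch4d, dim=2, keepdim=True), dim=-1, keepdim=True)

print(m3.shape)        # torch.Size([2, 1, 1])
print(m4.shape)        # torch.Size([1, 1, 1, 3])
print(m4_fixed.shape)  # torch.Size([1, 2, 1, 1])
```

In this sketch, `dim=1` only averages over the channels when the slice is 4-D. On a 3-D slice taken with an integer `nn`, the same call gives one mean per channel.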

niceDuckgu commented 2 years ago

Actually, this error only occurs for a 1×C×H×W image. The original code has no problem with a batch size of more than 2.
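One way to make the patch average independent of the input's dimensionality is to reduce over the last two axes by negative index. This is only a sketch: `patch_mean` is a hypothetical helper, not a function in the repository, and the tensor and indices are made up for illustration.

```python
import torch

def patch_mean(patch: torch.Tensor) -> torch.Tensor:
    """Average over the last two (spatial) axes, whatever the leading axes are."""
    return patch.mean(dim=(-2, -1))

# Hypothetical data: one image with C=2 ab channels, 3x3 patch at (h, w).
B = torch.rand(1, 2, 8, 8)
C, P, h, w, nn = 2, 3, 2, 4, 0

# Works whether the batch axis was dropped (integer index) or kept (slice index).
mean3d = patch_mean(B[nn, :, h:h+P, w:w+P]).view(1, C, 1, 1)
mean4d = patch_mean(B[nn:nn+1, :, h:h+P, w:w+P]).view(1, C, 1, 1)

print(mean3d.shape)                    # torch.Size([1, 2, 1, 1])
print(torch.allclose(mean3d, mean4d))  # True
```

Because `-2` and `-1` always refer to the height and width axes here, the result does not depend on how many leading axes the slice has.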