At line 248, the average patch values are not computed correctly.
In utils, add_color_patches_rand_gt() (line 248) computes the patch average as
torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)
However, I believe this code should be changed to
torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=-1,keepdim=True).view(1,C,1,1)
The original code computes the mean across the ab channels, whereas the proposed code computes it over the values within the patch.
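Which axis dim=1 refers to depends on how nn indexes the batch axis. An integer nn drops that axis, so the slice is (C, P, P) and dim=1 is the patch height. A slice such as nn:nn+1 keeps it, so the slice is (1, C, P, P) and dim=1 is the channel axis. The sketch below prints the shapes both expressions produce in each case. It uses random data in place of data['B'], and the sizes (C=2, P=5) and the helper names patch_mean_original / patch_mean_proposed are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical sizes for illustration: batch of 3, 2 ab channels, 16x16 image, 5x5 patch.
N, C, H, W = 3, 2, 16, 16
P = 5
h, w = 4, 6
data = {"B": torch.randn(N, C, H, W)}  # random stand-in for the ab ground truth


def patch_mean_original(patch):
    # Mean over dim=2, then over dim=1 (the expression at line 248, without .view).
    return torch.mean(torch.mean(patch, dim=2, keepdim=True), dim=1, keepdim=True)


def patch_mean_proposed(patch):
    # Mean over dim=2, then over dim=-1 (the proposed change, without .view).
    return torch.mean(torch.mean(patch, dim=2, keepdim=True), dim=-1, keepdim=True)


# Case 1: nn is an int, so the batch axis is dropped and the patch is (C, P, P).
nn = 0
patch_int = data["B"][nn, :, h:h + P, w:w + P]
print(patch_int.shape)                       # torch.Size([2, 5, 5])
print(patch_mean_original(patch_int).shape)  # torch.Size([2, 1, 1])
print(patch_mean_proposed(patch_int).shape)  # torch.Size([2, 5, 1])

# Case 2: nn is a slice, so the batch axis is kept and the patch is (1, C, P, P).
patch_slice = data["B"][nn:nn + 1, :, h:h + P, w:w + P]
print(patch_mean_original(patch_slice).shape)  # torch.Size([1, 1, 1, 5])
print(patch_mean_proposed(patch_slice).shape)  # torch.Size([1, 2, 1, 1])
```

The trailing .view(1,C,1,1) only succeeds when the reduced tensor holds exactly C elements, so the shapes above also show which of the two expressions yields one average per ab channel under each kind of indexing.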