IyatomiLab / LeafGAN


Mask transpose 2,1,0 correct? #4

Closed: BartvanMarrewijk closed this issue 2 years ago

BartvanMarrewijk commented 2 years ago

In leaf_gan_model.py, the function get_masking uses a transpose of (2,1,0) to convert the mask to a C x .. x .. shape, which turns any H x W x C input into C x W x H. I wonder whether this transpose is correct, and why a transpose of (2,0,1) is not used instead to convert the H x W x C array to C x H x W. Using the code below, there seems to be a mismatch between the heatmap and the image orientation when using a transpose of (2,1,0).
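A small NumPy check makes the difference concrete. The 2 x 3 x 3 array below is an arbitrary illustrative size, chosen to be non-square so the two layouts can be told apart. It also shows why the mistake is easy to miss: get_masking resizes the heatmap to crop_size x crop_size, and for a square mask both transposes produce the same shape, so nothing fails and the mask is just silently transposed.

```python
import numpy as np

# Illustrative non-square mask so the layouts differ (H=2, W=3, C=3).
H, W, C = 2, 3, 3
mask_hwc = np.arange(H * W * C).reshape(H, W, C)

wrong = mask_hwc.transpose(2, 1, 0)  # axes (C, W, H): spatial axes swapped
right = mask_hwc.transpose(2, 0, 1)  # axes (C, H, W): the layout PyTorch expects

print(wrong.shape)  # (3, 3, 2)
print(right.shape)  # (3, 2, 3)

# Per channel, the (2,1,0) result is the (2,0,1) result with rows and columns
# swapped, i.e. the mask is mirrored across its main diagonal.
assert np.array_equal(wrong, right.transpose(0, 2, 1))

# With a square crop (H == W) both transposes give the same shape,
# so no error is raised and only the mask orientation is wrong.
square = np.zeros((4, 4, 3))
assert square.transpose(2, 1, 0).shape == square.transpose(2, 0, 1).shape
```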

Another question: are any augmentations, such as flipping, used? I could not find them in the dataloader, but it seems that flipping occurs a few times.

```python
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

def get_masking(self, tensor, threshold):
    with torch.enable_grad():
        # probs, idx = self.netLFLSeg.forward(tensor)  # superseded by the next line
        probs, idx = self.netLFLSeg.forward(tensor.to(torch.device('cuda:1')))
        self.netLFLSeg.backward(idx=0)  # 0 for getting heatmap for "fully_leaf" class

        heat_map = self.netLFLSeg.generate(target_layer='layer4.2')  # 'layer4.2' is the best for our experiment
        heat_map = cv2.resize(heat_map, dsize=(self.opt.crop_size, self.opt.crop_size))

        background_mask = np.absolute(1.0 - (heat_map >= threshold))
        background_mask = np.stack((background_mask, background_mask, background_mask), axis=2)

        foreground_mask = heat_map >= threshold
        foreground_mask = np.stack((foreground_mask, foreground_mask, foreground_mask), axis=2)
        self.plot(tensor, background_mask, heat_map)  # debug visualization
        # return background_mask.astype(np.uint8), foreground_mask.astype(np.uint8)
        # The transpose in question: (2,1,0) maps H x W x C to C x W x H, not C x H x W
        background_mask = background_mask.astype(np.float32).transpose(2, 1, 0)
        foreground_mask = foreground_mask.astype(np.float32).transpose(2, 1, 0)
        return (torch.from_numpy(background_mask).unsqueeze(0).to(self.device),
                torch.from_numpy(foreground_mask).unsqueeze(0).to(self.device))

def plot(self, tensor, background_mask, heatmap):
    # Map the input tensor from [-1, 1] back to [0, 255]
    tensor_converted = (tensor + 1) * 255 / 2
    image = tensor_converted[0].detach().cpu().numpy().astype(np.uint8)  # C x H x W
    # plt.imshow(image.transpose(2, 1, 0))
    # plt.show()

    background_mask_tmp = background_mask * 255
    # plt.imshow(background_mask_tmp.astype(np.uint8))
    # plt.show()

    bgr = cv2.imread(self.img_name[0])  # original image from disk (OpenCV loads BGR)

    fig, axs = plt.subplots(2, 2)
    axs[0, 0].imshow(image.transpose(2, 1, 0))
    axs[0, 0].set_title('transposed input tensor (2,1,0)')
    axs[0, 1].imshow(image.transpose(1, 2, 0))
    axs[0, 1].set_title('transposed input tensor (1,2,0)')
    # axs[0, 1].imshow(background_mask_tmp.astype(np.uint8))
    axs[1, 0].imshow(bgr[:, :, ::-1])  # BGR -> RGB for matplotlib
    axs[1, 0].set_title('original')
    axs[1, 1].imshow(heatmap)
    axs[1, 1].set_title('heatmap')
    plt.show()
```

This results in the following image, in which the mask is oriented horizontally and does not match the orientation of the input.

huuquan1994 commented 2 years ago

Hi @studentWUR,

Thanks for your questions. For the first question, you're right. The correct transpose should be (2,0,1), i.e., from a NumPy image (H x W x C) to a torch image (C x H x W). Sorry that I didn't notice this mistake. leaf_gan_model.py has been updated with the correct transpose. Please check it.
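For reference, a minimal sketch of the mask construction with the corrected (2,0,1) transpose. heatmap_to_masks is a hypothetical standalone helper, not part of LeafGAN; it reproduces the thresholding and stacking steps of get_masking on a plain NumPy heatmap. It computes the background as the logical complement of the foreground, which gives the same 0/1 values as np.absolute(1.0 - (heat_map >= threshold)). In get_masking, the resulting arrays are then wrapped with torch.from_numpy(...).unsqueeze(0) to add the batch dimension.

```python
import numpy as np

def heatmap_to_masks(heat_map: np.ndarray, threshold: float):
    """Hypothetical helper: H x W heatmap -> 3 x H x W background/foreground masks."""
    foreground = heat_map >= threshold               # H x W, bool
    background = ~foreground                         # complement of the foreground
    foreground = np.stack([foreground] * 3, axis=2)  # H x W x 3
    background = np.stack([background] * 3, axis=2)  # H x W x 3
    # NumPy image layout (H, W, C) -> PyTorch layout (C, H, W)
    return (background.astype(np.float32).transpose(2, 0, 1),
            foreground.astype(np.float32).transpose(2, 0, 1))

# Usage: a 2 x 3 heatmap with one "leaf" pixel at row 0, column 2
heat = np.zeros((2, 3), dtype=np.float32)
heat[0, 2] = 0.9
bg, fg = heatmap_to_masks(heat, threshold=0.5)
# fg has shape (3, 2, 3) and the foreground pixel stays at row 0, column 2
```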

For the second question, we follow the training procedure of the original CycleGAN repository (link), which had no data augmentation. If you want to improve the results, you can add your own data augmentation to your code :)

If you find any mistakes, please feel free to let me know!

BartvanMarrewijk commented 2 years ago

Hi @huuquan1994,

Thank you for updating it.

For the second question, thank you for the confirmation. Somehow I still see that roughly one out of every five images is mirrored. In the example above, the upper-right and lower-left panels are mirrored relative to each other. I could not figure out why this happens; do you have any ideas?

huuquan1994 commented 2 years ago

Hi @studentWUR,

I found that the original CycleGAN implementation uses transforms.RandomHorizontalFlip() (and so does our LeafGAN). Since the CycleGAN paper didn't describe its data augmentation, I had assumed that no augmentation was used.

It turns out that it actually has that option, controlled by the flag no_flip = False in base_dataset.py (I only found this out thanks to your question :) ). By default this value is False, which results in a random horizontal flip. To disable the flip, please set the no_flip flag to True.
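A simplified sketch of how a no_flip flag typically gates the flip when the transform list is built. The helpers below (random_horizontal_flip, build_transforms, apply) are hypothetical NumPy stand-ins, not the repository's code; in the repository the flip is torchvision's transforms.RandomHorizontalFlip(), which flips with probability 0.5 by default. The flip is drawn independently for every sample, so the share of mirrored images in any given batch varies. In the upstream CycleGAN options, no_flip is a store_true flag, so it is enabled by passing --no_flip on the command line; this assumes LeafGAN kept that option unchanged.

```python
import random
import numpy as np

def random_horizontal_flip(img: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Mirror an H x W x C image left-to-right with probability p."""
    if random.random() < p:
        return img[:, ::-1, :].copy()
    return img

def build_transforms(no_flip: bool):
    """Hypothetical stand-in for the transform list built in base_dataset.py."""
    transforms = []
    if not no_flip:  # default no_flip=False -> random flipping is enabled
        transforms.append(random_horizontal_flip)
    return transforms

def apply(transforms, img):
    for t in transforms:
        img = t(img)
    return img

img = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # non-symmetric, so a flip is detectable
random.seed(0)

# With the default no_flip=False, about half of the samples come out mirrored.
flips = sum(not np.array_equal(apply(build_transforms(no_flip=False), img), img)
            for _ in range(1000))

# With no_flip=True, the image is never changed.
never_flipped = all(np.array_equal(apply(build_transforms(no_flip=True), img), img)
                    for _ in range(100))
```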

I hope this helps!