Cheng-Lin-Li / SegCaps

A clone of the original SegCaps source code with enhancements for the MS COCO dataset.
Apache License 2.0

RGB images with binary masks #21

Open · fatihergin opened this issue 4 years ago


Hi @Cheng-Lin-Li, thanks for your great effort.

I have a question about using a custom dataset. Is it possible to train SegCapsR3 with 3-channel RGB images and 1-channel grayscale masks?

UNet works with the same configuration (3-channel input + 1-channel mask), but SegCapsR3 does not.

SegCapsR3 fails with the following error:

`ValueError: Error when checking target: expected out_recon to have shape (224, 224, 1) but got array with shape (224, 224, 3)`
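The shapes in this error can be reproduced with plain NumPy. This is a minimal sketch that assumes (it is not confirmed in this issue) that the data generator builds the reconstruction target by multiplying the input image by the ground-truth mask. Under that assumption, broadcasting a (224, 224, 1) mask against a (224, 224, 3) RGB image yields a 3-channel target, while `out_recon` emits only one channel. The grayscale step at the end is one hypothetical workaround, not the repository's fix.

```python
import numpy as np

H, W = 224, 224
rng = np.random.default_rng(0)

# RGB input image (3 channels) and binary ground-truth mask (1 channel).
img = rng.random((H, W, 3), dtype=np.float32)
mask = (rng.random((H, W, 1)) > 0.5).astype(np.float32)

# Assumption: the reconstruction target is mask * image.
# Broadcasting keeps all 3 image channels.
recon_target = mask * img
print(recon_target.shape)  # (224, 224, 3)

# The out_recon layer in the model summary has a single output channel.
out_recon_shape = (H, W, 1)
print(recon_target.shape == out_recon_shape)  # False, hence the ValueError

# Hypothetical workaround: collapse RGB to one channel before masking,
# so the target matches the 1-channel out_recon output.
gray = img.mean(axis=-1, keepdims=True)
recon_target_gray = mask * gray
print(recon_target_gray.shape)  # (224, 224, 1)
```

The alternative is to keep the 3-channel target and give `out_recon` as many filters as the input has channels; which option fits depends on how the training code builds its targets.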

Here is the network summary. I haven't modified the network; I'm just getting started and couldn't find a solution. Regards,

```
Layer (type)                         Output Shape              Param #   Connected to
input_1 (InputLayer)                 (None, 224, 224, 3)       0
conv1 (Conv2D)                       (None, 224, 224, 16)      1216      input_1[0][0]
reshape_1 (Reshape)                  (None, 224, 224, 1, 16)   0         conv1[0][0]
primarycaps (ConvCapsuleLayer)       (None, 112, 112, 2, 16)   12832     reshape_1[0][0]
conv_cap_2_1 (ConvCapsuleLayer)      (None, 112, 112, 4, 16)   25664     primarycaps[0][0]
conv_cap_2_2 (ConvCapsuleLayer)      (None, 56, 56, 4, 32)     51328     conv_cap_2_1[0][0]
conv_cap_3_1 (ConvCapsuleLayer)      (None, 56, 56, 8, 32)     205056    conv_cap_2_2[0][0]
conv_cap_3_2 (ConvCapsuleLayer)      (None, 28, 28, 8, 64)     410112    conv_cap_3_1[0][0]
conv_cap_4_1 (ConvCapsuleLayer)      (None, 28, 28, 8, 32)     409856    conv_cap_3_2[0][0]
deconv_cap_1_1 (DeconvCapsuleLayer)  (None, 56, 56, 8, 32)     131328    conv_cap_4_1[0][0]
up_1 (Concatenate)                   (None, 56, 56, 16, 32)    0         deconv_cap_1_1[0][0]
                                                                         conv_cap_3_1[0][0]
deconv_cap_1_2 (ConvCapsuleLayer)    (None, 56, 56, 4, 32)     102528    up_1[0][0]
deconv_cap_2_1 (DeconvCapsuleLayer)  (None, 112, 112, 4, 16)   32832     deconv_cap_1_2[0][0]
up_2 (Concatenate)                   (None, 112, 112, 8, 16)   0         deconv_cap_2_1[0][0]
                                                                         conv_cap_2_1[0][0]
deconv_cap_2_2 (ConvCapsuleLayer)    (None, 112, 112, 4, 16)   25664     up_2[0][0]
deconv_cap_3_1 (DeconvCapsuleLayer)  (None, 224, 224, 2, 16)   8224      deconv_cap_2_2[0][0]
up_3 (Concatenate)                   (None, 224, 224, 3, 16)   0         deconv_cap_3_1[0][0]
                                                                         reshape_1[0][0]
seg_caps (ConvCapsuleLayer)          (None, 224, 224, 1, 16)   272       up_3[0][0]
input_2 (InputLayer)                 (None, 224, 224, 1)       0
mask_1 (Mask)                        (None, 224, 224, 1, 16)   0         seg_caps[0][0]
                                                                         input_2[0][0]
reshape_2 (Reshape)                  (None, 224, 224, 16)      0         mask_1[0][0]
recon_1 (Conv2D)                     (None, 224, 224, 64)      1088      reshape_2[0][0]
recon_2 (Conv2D)                     (None, 224, 224, 128)     8320      recon_1[0][0]
out_seg (Length)                     (None, 224, 224, 1)       0         seg_caps[0][0]
out_recon (Conv2D)                   (None, 224, 224, 1)       129       recon_2[0][0]
```
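The Conv2D parameter counts in this summary can be checked by hand: a Conv2D layer with bias has kernel_h × kernel_w × in_channels × filters weights plus one bias per filter. The sketch below uses kernel sizes inferred from the counts themselves (an assumption, not read from the SegCaps source). It shows that `conv1` does see all 3 RGB channels, while `out_recon` has a single filter and so can only produce a 1-channel reconstruction.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters, use_bias=True):
    """Number of trainable parameters in a standard 2D convolution layer."""
    weights = kernel_h * kernel_w * in_channels * filters
    return weights + (filters if use_bias else 0)

# Kernel sizes are inferred from the summary's counts (assumption):
# conv1 looks like a 5x5 kernel, the recon layers like 1x1 kernels.
layer_params = {
    "conv1": conv2d_params(5, 5, 3, 16),       # 3-channel RGB input
    "recon_1": conv2d_params(1, 1, 16, 64),
    "recon_2": conv2d_params(1, 1, 64, 128),
    "out_recon": conv2d_params(1, 1, 128, 1),  # one output channel
}
for name, count in layer_params.items():
    print(name, count)

# Hypothetical: an out_recon with 3 filters (one per RGB channel).
print("out_recon (3 filters)", conv2d_params(1, 1, 128, 3))
```

The printed counts (1216, 1088, 8320, 129) match the summary, which supports reading `out_recon` as a 1x1 convolution with one filter.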