zengxianyu / crfill

CR-Fill: Generative Image Inpainting with Auxiliary Contextual Reconstruction. ICCV 2021

input for test.py includes original complete image? #6

Open mac744mail opened 3 years ago

mac744mail commented 3 years ago

Hi, I'm wondering: when running the test, what should the input be? The mask plus the original (complete) image, or the mask plus the masked (incomplete) image?

zengxianyu commented 3 years ago

Both are fine, as long as the white area in the mask covers the missing region when the input is mask + image (incomplete).
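For reference, the reason both inputs work can be sketched in a few lines: the pixels inside the white mask region are discarded before the network sees them, so it makes no difference whether those pixels were intact or already blanked out. The file names and the white-means-hole convention below are illustrative assumptions, not code from this repository:

# Sketch: build the masked network input from an image + mask pair.
# Assumes the mask is a grayscale PNG where white (255) marks the hole;
# "example.png" / "example_mask.png" are placeholder file names.
import numpy as np
from PIL import Image

image = np.array(Image.open("example.png").convert("RGB"), dtype=np.float32)
mask = np.array(Image.open("example_mask.png").convert("L"), dtype=np.float32)
hole = (mask > 127.5)[..., None]   # True inside the region to fill

# Pixels inside the hole are zeroed out, so it does not matter whether
# "example.png" was complete or already had the hole blanked out.
masked_input = image * (~hole)
Image.fromarray(masked_input.astype(np.uint8)).save("masked_input.png")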

mac744mail commented 3 years ago

Both are fine, as long as the white area in the mask covers the missing region when the input is mask + image (incomplete).

Thanks. I tried to run your model on my own dataset. The dataset is for removing the contents of article pages and inpainting the background, so the masked areas are the article contents. Training ran for 500 epochs, ending with an L1c of 0.069 and a PSNR of 9.715; the PSNR fluctuated around 9-10. The test results looked like this: page_1006

page_1103

How come? Is something wrong? I would be grateful if you could give some comments or instructions on how to fix these bad results.
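As context for the numbers above, PSNR is defined as 10 * log10(MAX^2 / MSE), and inpainting results typically score well above 20 dB, so a value around 9-10 dB does indicate outputs that are far from the ground truth. A minimal, generic sketch of that formula (not necessarily the exact implementation this repository uses):

# Sketch: peak signal-to-noise ratio between two 8-bit images.
import numpy as np

def psnr(reference: np.ndarray, result: np.ndarray, max_value: float = 255.0) -> float:
    # Mean squared error in float64 to avoid uint8 overflow.
    mse = np.mean((reference.astype(np.float64) - result.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(max_value ** 2 / mse)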

zengxianyu commented 3 years ago

Did you change --train_image_dir and --train_image_list in train.sh to point to your dataset?

mac744mail commented 3 years ago

Did you change --train_image_dir and --train_image_list in train.sh to point to your dataset?

Yes, I think I did it correctly for my own dataset. I remember that at first there were some errors telling me that XXXX.png could not be found. After I made sure all the paths were correct, the model started running. So I'm sure my paths were OK, since the model eventually ran smoothly without the "xxx.png cannot be found" error.
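One quick way to rule out the path issue described above is to check that every entry in the training list resolves to a file under the image directory. The sketch below assumes the list file holds one file name per line that is joined with --train_image_dir; the paths are the ones from the training command quoted later in this thread:

# Sketch: verify that every entry in the image list resolves to an existing file.
import os

image_dir = "./datasets/tr_bg_comp"      # value passed to --train_image_dir
image_list = "./datasets/train.txt"      # value passed to --train_image_list

with open(image_list) as f:
    names = [line.strip() for line in f if line.strip()]

missing = [n for n in names if not os.path.isfile(os.path.join(image_dir, n))]
print(f"{len(names)} entries, {len(missing)} missing")
for n in missing[:10]:                   # show at most the first few offenders
    print("missing:", n)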

zengxianyu commented 3 years ago

Did you check the HTML file in output/ produced during training?

mac744mail commented 3 years ago

Did you check the HTML file in output/ produced during training?

Check the HTML file? I didn't run the Flask web demo. I just trained the model and then tested it. Here are the two commands I used:

1.

BSIZE=40 # 96:64G
BSIZE0=$((BSIZE/2))
NWK=16
PREFIX="--dataset_mode_train trainimage \
--gpu_ids 0,1 \
--name debug \
--dataset_mode_val valimage \
--train_image_dir ./datasets/tr_bg_comp \
--train_image_list ./datasets/train.txt \
--path_objectshape_list ./datasets/mask_train.txt \
--path_objectshape_base ./datasets/objmask_tr \
--val_image_dir ./datasets/val_bg_comp \
--val_image_list ./datasets/val.txt \
--val_mask_dir ./datasets/objmask_val \
--no_vgg_loss \
--no_ganFeat_loss \
--load_size 1280 \
--crop_size 256 \
--model inpaint \
--netG baseconv \
--netD deepfill \
--preprocess_mode scale_shortside_and_crop \
--validation_freq 10000 \
--gpu_ids 0,1 \
--niter 500 "

python train.py \
${PREFIX} \
--batchSize ${BSIZE0} \
--nThreads ${NWK} \
--no_fine_loss \
--update_part coarse \
--no_gan_loss \
--freeze_D \
--niter 500 \
${EXTRA}

2. The above command trained the model and produced .pth weight files under the checkpoint/debug folder. Then I ran test.py:

python test.py \
--batchSize 1 \
--nThreads 1 \
--name debug \
--dataset_mode testimage \
--image_dir ./datasets/testimg \
--mask_dir ./datasets/testmk \
--output_dir ./results \
--model inpaint \
--netG baseconv \
--which_epoch latest

zengxianyu commented 3 years ago

There should be an HTML file produced in output/ by the training code. Can you check it?

mac744mail commented 3 years ago

There should be an HTML file produced in output/ by the training code. Can you check it?

Yes, I found it. I'm attaching it here (converted to a PDF): htmltopdf.pdf

zengxianyu commented 3 years ago

It seems that you deleted some lines of train.sh and only ran a part of the training script.

mac744mail commented 3 years ago

It seems that you deleted some lines of train.sh and only ran a part of the training script.

Yes, there are several "python train.py ..." invocations in train.sh. I only copied the lines up to the end of the first "python train.py" part and ignored the rest; I assumed they were duplicates.

zengxianyu commented 3 years ago

They are for different stages. You'll need to use all of them for training.

mac744mail commented 3 years ago

They are for different stages. You'll need to use all of them for training.

Well, thanks for your time. I'll try it.

mac744mail commented 3 years ago

They are for different stages. You'll need to use all of them for training.

By the way, could you please tell me the meanings of the arguments "--path_objectshape_list" and "--path_objectshape_base"? I thought "--path_objectshape_list" is the path to a txt file that lists the file names of the mask images, and "--path_objectshape_base" is the path to the folder where those mask images are. Is that right?

zengxianyu commented 3 years ago

That's right, and you don't need to modify these lines. Square masks and irregular masks are randomly synthesized during training, and the object masks are already prepared if you run download.sh before training.
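To illustrate the two arguments discussed above, here is a minimal sketch of how a mask list plus base folder would be consumed: the list names the mask files, and the base directory is prepended to each name. The paths are taken from the command quoted earlier; the actual sampling and synthesis logic lives in the repository's dataset code:

# Sketch: resolve object-shape masks from a list file plus a base directory,
# mirroring the meaning of --path_objectshape_list / --path_objectshape_base.
import os
import random

path_objectshape_base = "./datasets/objmask_tr"       # folder containing the mask images
path_objectshape_list = "./datasets/mask_train.txt"   # one mask file name per line

with open(path_objectshape_list) as f:
    mask_names = [line.strip() for line in f if line.strip()]

# During training an object mask would be drawn at random, e.g.:
mask_path = os.path.join(path_objectshape_base, random.choice(mask_names))
print("sampled object mask:", mask_path)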

mac744mail commented 3 years ago

That's right, and you don't need to modify these lines. Square masks and irregular masks are randomly synthesized during training, and the object masks are already prepared if you run download.sh before training.

They are for different stages. You'll need to use all of them for training.

Sorry, but after running train.sh I got the following error when running test.sh. Could you help me?

_"RuntimeError: Error(s) in loading state_dict for BaseConvGenerator: Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2_downsample.weight", "conv2_downsample.bias", "conv3.weight", "conv3.bias", "conv4_downsample.weight", "conv4_downsample.bias", "conv5.weight", "conv5.bias", "conv6.weight", "conv6.bias", "conv7_atrous.weight", "conv7_atrous.bias", "conv8_atrous.weight", "conv8_atrous.bias", "conv9_atrous.weight", "conv9_atrous.bias", "conv10_atrous.weight", "conv10_atrous.bias", "conv11.weight", "conv11.bias", "conv12.weight", "conv12.bias", "xconv1.weight", "xconv1.bias", "xconv2_downsample.weight", "xconv2_downsample.bias", "xconv3.weight", "xconv3.bias", "xconv4_downsample.weight", "xconv4_downsample.bias", "xconv5.weight", "xconv5.bias", "xconv6.weight", "xconv6.bias", "xconv7_atrous.weight", "xconv7_atrous.bias", "xconv8_atrous.weight", "xconv8_atrous.bias", "xconv9_atrous.weight", "xconv9_atrous.bias", "xconv10_atrous.weight", "xconv10_atrous.bias", "pmconv1.weight", "pmconv1.bias", "pmconv2_downsample.weight", "pmconv2_downsample.bias", "pmconv3.weight", "pmconv3.bias", "pmconv4_downsample.weight", "pmconv4_downsample.bias", "pmconv5.weight", "pmconv5.bias", "pmconv6.weight", "pmconv6.bias", "pmconv9.weight", "pmconv9.bias", "pmconv10.weight", "pmconv10.bias", "allconv11.weight", "allconv11.bias", "allconv12.weight", "allconv12.bias", "allconv13_upsample_conv.weight", "allconv13_upsample_conv.bias", "allconv14.weight", "allconv14.bias", "allconv15_upsample_conv.weight", "allconv15_upsample_conv.bias", "allconv16.weight", "allconv16.bias", "allconv17.weight", "allconv17.bias". Unexpected key(s) in state_dict: "baseg.conv1.weight", "baseg.conv1.bias", "baseg.conv2_downsample.weight", "baseg.conv2_downsample.bias", "baseg.conv3.weight", "baseg.conv3.bias", "baseg.conv4_downsample.weight", "baseg.conv4_downsample.bias", "baseg.conv5.weight", "baseg.conv5.bias", "baseg.conv6.weight", "baseg.conv6.bias", "baseg.conv7_atrous.weight", "baseg.conv7_atrous.bias", "baseg.conv8_atrous.weight", "baseg.conv8_atrous.bias", "baseg.conv9_atrous.weight", "baseg.conv9_atrous.bias", "baseg.conv10_atrous.weight", "baseg.conv10_atrous.bias", "baseg.conv11.weight", "baseg.conv11.bias", "baseg.conv12.weight", "baseg.conv12.bias", "baseg.conv13_upsample_conv.weight", "baseg.conv13_upsample_conv.bias", "baseg.conv14.weight", "baseg.conv14.bias", "baseg.conv15_upsample_conv.weight", "baseg.conv15_upsample_conv.bias", "baseg.conv16.weight", "baseg.conv16.bias", "baseg.conv17.weight", "baseg.conv17.bias", "baseg.xconv1.weight", "baseg.xconv1.bias", "baseg.xconv2_downsample.weight", "baseg.xconv2_downsample.bias", "baseg.xconv3.weight", "baseg.xconv3.bias", "baseg.xconv4_downsample.weight", "baseg.xconv4_downsample.bias", "baseg.xconv5.weight", "baseg.xconv5.bias", "baseg.xconv6.weight", "baseg.xconv6.bias", "baseg.xconv7_atrous.weight", "baseg.xconv7_atrous.bias", "baseg.xconv8_atrous.weight", "baseg.xconv8_atrous.bias", "baseg.xconv9_atrous.weight", "baseg.xconv9_atrous.bias", "baseg.xconv10_atrous.weight", "baseg.xconv10_atrous.bias", "baseg.pmconv1.weight", "baseg.pmconv1.bias", "baseg.pmconv2_downsample.weight", "baseg.pmconv2_downsample.bias", "baseg.pmconv3.weight", "baseg.pmconv3.bias", "baseg.pmconv4_downsample.weight", "baseg.pmconv4_downsample.bias", "baseg.pmconv5.weight", "baseg.pmconv5.bias", "baseg.pmconv6.weight", "baseg.pmconv6.bias", "baseg.pmconv9.weight", "baseg.pmconv9.bias", "baseg.pmconv10.weight", "baseg.pmconv10.bias", "baseg.allconv11.weight", "baseg.allconv11.bias", 
"baseg.allconv12.weight", "baseg.allconv12.bias", "baseg.allconv13_upsample_conv.weight", "baseg.allconv13_upsample_conv.bias", "baseg.allconv14.weight", "baseg.allconv14.bias", "baseg.allconv15_upsample_conv.weight", "baseg.allconv15_upsample_conv.bias", "baseg.allconv16.weight", "baseg.allconv16.bias", "baseg.allconv17.weight", "baseg.allconv17.bias", "sconv1.weight", "sconv1.bias", "sconv2.weight", "sconv2.bias", "bconv1.weight", "bconv1.bias", "bconv2_downsample.weight", "bconv2_downsample.bias", "bconv3.weight", "bconv3.bias", "bconv4_downsample.weight", "bconv4_downsample.bias", "conv16_2.weight", "conv162.bias". size mismatch for conv14.weight: copying a param with shape torch.Size([96, 96, 3, 3]) from checkpoint, the shape in current model is torch.Size([96, 48, 3, 3]). size mismatch for conv16.weight: copying a param with shape torch.Size([48, 48, 3, 3]) from checkpoint, the shape in current model is torch.Size([24, 24, 3, 3]). size mismatch for conv16.bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([24]). size mismatch for conv17.weight: copying a param with shape torch.Size([3, 24, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 12, 3, 3])."

zengxianyu commented 3 years ago

It's fixed. Please pull to use the latest code.

all1new commented 2 years ago

Hello, when I run train.py I get the error: ImportError: cannot import name 'Logger' from 'logger'. Have you encountered this as well?