AiArt-Gao / MATEBIT

[CVPR'23] Masked and Adaptive Transformer for Exemplar Based Image Translation (MATEBIT)
MIT License

Has anyone tried to reproduce the results on the CelebA dataset? #7

Open DwanAI opened 8 months ago

DwanAI commented 8 months ago

After changing some file paths, I get this error:

    RuntimeError: Given groups=1, weight of size [64, 15, 3, 3], expected input[2, 3, 258, 258] to have 15 channels, but got 3 channels instead
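PyTorch's `Conv2d` stores its weight as `[out_channels, in_channels // groups, kH, kW]`, so the weight `[64, 15, 3, 3]` in this message belongs to a convolution that expects 15 input channels, while the tensor reaching it (`[2, 3, 258, 258]`, i.e. batch, channels, height, width) has only 3. The sketch below reproduces that check on plain shape tuples; `conv2d_channel_error` is a hypothetical helper, not part of MATEBIT or PyTorch.

```python
def conv2d_channel_error(weight_shape, input_shape, groups=1):
    """Return an error string if input_shape does not fit a Conv2d whose
    weight has shape (out_channels, in_channels // groups, kH, kW), else None.

    Hypothetical helper: mirrors the channel check PyTorch performs,
    using tuples instead of tensors.
    """
    expected = weight_shape[1] * groups  # total input channels the conv accepts
    got = input_shape[1]  # NCHW layout: dim 1 holds the channels
    if got == expected:
        return None
    return (
        f"Given groups={groups}, weight of size {list(weight_shape)}, "
        f"expected input{list(input_shape)} to have {expected} channels, "
        f"but got {got} channels instead"
    )


# The case from this issue: a 15-channel encoder receives a 3-channel RGB image.
print(conv2d_channel_error((64, 15, 3, 3), (2, 3, 258, 258)))
```

In other words, whichever module was built with 15 input channels (the label channel count) is being fed an RGB image; either the module or its input has to change.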

YukunZEX commented 7 months ago

I got the same problem. [screenshot]

DwanAI commented 7 months ago

Have you solved this problem? I think it's a mismatch in the model's input channels, but I haven't found a fix.

aiartjc666 commented 7 months ago
The problem is this line:

    self.encoder_kv = FeatureGenerator(opt.semantic_nc, opt.ngf, opt.max_multi, norm='instance')

Either change semantic_nc to 3 channels, or, in preprocess_input in pix2pix_model.py, keep only the first 3 channels, e.g. [:, :3, :, :].

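The slice `[:, :3, :, :]` keeps every sample in the batch and only the first three channels, leaving height and width untouched. A plain-Python sketch on a nested list in `[N][C][H][W]` order shows the shape effect; `shape` and `keep_first_channels` are hypothetical helpers for illustration only.

```python
def shape(t):
    """Shape of a rectangular nested list, e.g. [N][C][H][W] -> (N, C, H, W)."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0] if t else None
    return tuple(dims)


def keep_first_channels(batch, n):
    """Equivalent of tensor[:, :n, :, :] for an [N][C][H][W] nested list."""
    return [sample[:n] for sample in batch]


# Batch of 2 samples, 15 channels, 4x4 spatial size, filled with zeros.
x = [[[[0.0] * 4 for _ in range(4)] for _ in range(15)] for _ in range(2)]
print(shape(x))                          # (2, 15, 4, 4)
print(shape(keep_first_channels(x, 3)))  # (2, 3, 4, 4)
```

Slicing silently discards the remaining channels, so whether the first three are the right ones depends on how preprocess_input assembles the tensor.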
aiartjc666 commented 7 months ago

Alternatively, you can give encoder_kv three extra input channels and feed it the reference image concatenated with its segmentation map:

    self.encoder_q = FeatureGenerator(opt.semantic_nc, opt.ngf, opt.max_multi, norm='instance')
    self.encoder_kv = FeatureGenerator(opt.semantic_nc+3, opt.ngf, opt.max_multi, norm='instance')

Then in the forward pass:

    ref_input = torch.cat((ref_img, ref_seg_map), dim=1)
    out['warp_out'] = []
    adaptive_feature_seg = self.encoder_q(seg_map)
    adaptive_feature_img = self.encoder_kv(ref_input)

    for i in range(len(adaptive_feature_seg)):
        adaptive_feature_seg[i] = util.feature_normalize(adaptive_feature_seg[i])
        adaptive_feature_img[i] = util.feature_normalize(adaptive_feature_img[i])

    if self.opt.isTrain and self.opt.weight_novgg_featpair > 0:
        real_input = torch.cat((real_img, seg_map), dim=1)
        adaptive_feature_img_pair = self.encoder_kv(real_input)
        loss_novgg_featpair = 0
        weights = [1.0, 1.0, 1.0, 1.0]
        for i in range(len(adaptive_feature_img_pair)):
            adaptive_feature_img_pair[i] = util.feature_normalize(adaptive_feature_img_pair[i])
            loss_novgg_featpair += F.l1_loss(adaptive_feature_seg[i], adaptive_feature_img_pair[i]) * weights[i]
        out['loss_novgg_featpair'] = loss_novgg_featpair * self.opt.weight_novgg_featpair
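The `loss_novgg_featpair` term above normalizes each scale's features and sums a weighted L1 distance between the segmentation-branch features and the encoder_kv features of the real image. A plain-Python sketch, assuming `util.feature_normalize` divides each spatial position's channel vector by its L2 norm (an assumption, not confirmed in this thread) and that `F.l1_loss` uses its default mean reduction; features are lists of per-position channel vectors rather than tensors.

```python
import math


def feature_normalize(vectors, eps=1e-10):
    """Scale each per-position channel vector to unit L2 norm (assumed behaviour)."""
    out = []
    for v in vectors:
        norm = math.sqrt(sum(c * c for c in v)) + eps
        out.append([c / norm for c in v])
    return out


def l1_loss(a, b):
    """Mean absolute difference over all elements, like F.l1_loss with reduction='mean'."""
    diffs = [abs(x - y) for va, vb in zip(a, b) for x, y in zip(va, vb)]
    return sum(diffs) / len(diffs)


def featpair_loss(feats_seg, feats_img, weights, loss_weight):
    """Weighted multi-scale L1 between normalized feature pyramids.

    Each pyramid is a list of scales; each scale is a list of channel
    vectors, one per spatial position.
    """
    total = 0.0
    for seg, img, w in zip(feats_seg, feats_img, weights):
        total += l1_loss(feature_normalize(seg), feature_normalize(img)) * w
    return total * loss_weight


# One scale, one position, two channels: [3, 4] -> [0.6, 0.8], [0, 2] -> [0, 1].
loss = featpair_loss([[[3.0, 4.0]]], [[[0.0, 2.0]]], weights=[1.0], loss_weight=1.0)
print(round(loss, 6))  # 0.4
```

The sketch normalizes both inputs; in the thread's code adaptive_feature_seg was already normalized earlier, and normalizing a unit vector again leaves it essentially unchanged.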

YukunZEX commented 7 months ago

I found that the training code runs, but the test code has this problem.

aiartjc666 commented 7 months ago

If you use this code, train with label_nc 3. If you use the pretrained model, or train with label_nc 15, use the CelebA generator.py.

YukunZEX commented 7 months ago

Thanks. I tried generator_celebahqedge.py and can train now, but the test code has a new problem. [screenshot]

YukunZEX commented 7 months ago

By the way, can I train on my own face dataset, such as Vox, which doesn't have segmentation labels?

aiartjc666 commented 7 months ago

You can use the HED algorithm to extract edge maps, as was done for MetFaces.
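HED (holistically-nested edge detection) is a learned edge detector, so running it requires a pretrained network. To show only the shape of this preprocessing step (a grayscale face image in, a single-channel edge map out, which then serves as the label), the sketch below swaps in a plain Sobel gradient filter with a threshold. It is a stand-in, not HED, and `sobel_edges` is a hypothetical name.

```python
import math


def sobel_edges(gray, threshold):
    """Binary edge map from a 2-D grayscale image (list of rows, values 0-255).

    Sobel gradient magnitude, thresholded: a simple stand-in for HED.
    Border pixels are left as 0.
    """
    h, w = len(gray), len(gray[0])
    gx_k = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]  # horizontal gradient kernel
    gy_k = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]  # vertical gradient kernel
    edges = [[0] * w for _ in range(h)]
    for y in range(1, h - 1):
        for x in range(1, w - 1):
            gx = sum(gx_k[j][i] * gray[y + j - 1][x + i - 1] for j in range(3) for i in range(3))
            gy = sum(gy_k[j][i] * gray[y + j - 1][x + i - 1] for j in range(3) for i in range(3))
            edges[y][x] = 255 if math.hypot(gx, gy) > threshold else 0
    return edges


# A 5x5 image with a vertical step from dark (0) to bright (255).
img = [[0, 0, 255, 255, 255] for _ in range(5)]
for row in sobel_edges(img, threshold=100):
    print(row)
```

The result has one channel, which is exactly the 1-channel-label situation discussed further down in this thread.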

YukunZEX commented 7 months ago

How should I set the parameters if I have a 1-channel label and a 3-channel image? If I set label_nc=1, it fails at

    adaptive_feature_img = self.encoder_kv(ref_img)

with:

    Given groups=1, weight of size [64, 1, 3, 3], expected input[4, 3, 258, 258] to have 1 channels, but got 3 channels instead

JcccKing commented 7 months ago

This is a channel mismatch. Either read the label as 3 channels, for example:

    Image.open(path).convert("RGB")  # then label_nc = 3

or keep label_nc = 1, build encoder_kv with extra input channels, and concatenate in forward:

    # __init__: build encoder_kv with (self.semantic_nc + input_channels) input channels
    # forward:
    ....
    self.encoder_kv(torch.cat((ref_img, seg_label), dim=1))
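Both suggestions come down to channel bookkeeping. For a grayscale ("L" mode) label, PIL's `convert("RGB")` copies the single value into R, G and B, giving a 3-channel label. The second route keeps the 1-channel label and concatenates it with the 3-channel reference image along the channel dimension, so encoder_kv must accept semantic_nc + 3 channels. A plain-Python sketch on `[C][H][W]` nested lists, assuming semantic_nc equals label_nc (no extra dontcare or instance channels); `gray_to_rgb` and `cat_channels` are hypothetical helpers.

```python
def gray_to_rgb(label):
    """[H][W] single-channel label -> [3][H][W], copying the value into each
    channel, which is what PIL's convert("RGB") does for an "L" image."""
    return [[row[:] for row in label] for _ in range(3)]


def cat_channels(*tensors):
    """Concatenate [C][H][W] nested lists along the channel axis
    (torch.cat on dim=1 once a batch dimension is added)."""
    out = []
    for t in tensors:
        out.extend(t)
    return out


label = [[0, 1], [1, 0]]                  # 1-channel 2x2 edge/label map
ref_img = [[[0.5, 0.5], [0.5, 0.5]]] * 3  # 3-channel 2x2 reference image

# Route 1: read the label as RGB -> label_nc = 3; encoder_kv(ref_img) then matches.
print(len(gray_to_rgb(label)))  # 3

# Route 2: keep label_nc = 1 and build encoder_kv with semantic_nc + 3 inputs;
# feed it the reference image and the label concatenated along channels.
ref_input = cat_channels(ref_img, [label])
print(len(ref_input))  # 4
```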

achaosss commented 4 months ago

How do I use generator_celebahqedge.py?