[Open] yx-yyds opened this issue 1 month ago
if self.semi_conv: if self.comp_feat_upsample: if self.use_high_pass: mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat) mask_hr_init = self.kernel_normalizer(mask_hr_hr_feat, self.highpass_kernel, hamming=self.hamming_highpass) compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)
mask_lr_hr_feat = self.content_encoder(compressed_hr_feat) mask_lr_init = self.kernel_normalizer(mask_lr_hr_feat, self.lowpass_kernel, hamming=self.hamming_lowpass) mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat) mask_lr_lr_feat = F.interpolate( carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest') mask_lr = mask_lr_hr_feat + mask_lr_lr_feat mask_lr_init = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass) mask_hr_lr_feat = F.interpolate( carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest') mask_hr = mask_hr_hr_feat + mask_hr_lr_feat else: raise NotImplementedError else: mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(self.content_encoder(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest') if self.use_high_pass: mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(self.content_encoder2(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest') else: compressed_x = F.interpolate(compressed_lr_feat, size=compressed_hr_feat.shape[-2:], mode='nearest') + compressed_hr_feat mask_lr = self.content_encoder(compressed_x) if self.use_high_pass: mask_hr = self.content_encoder2(compressed_x)
你好，这部分代码能加个注释吗？看起来和你论文中画的结构图不一样，没看懂。
感谢您的关注！理念是一致的，只是实现得有点绕，您也可以尝试简化。我已经更新了注释，供您参考：https://github.com/Linwei-Chen/FreqFusion/blob/main/FreqFusion.py
好的，谢谢。
请问论文图中的 Z 在代码中对应的变量名是什么？
感谢关注，compressed_x 可以当作 Z。
if self.semi_conv: if self.comp_feat_upsample: if self.use_high_pass: mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat) mask_hr_init = self.kernel_normalizer(mask_hr_hr_feat, self.highpass_kernel, hamming=self.hamming_highpass) compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)
你好，这部分代码能加个注释吗？看起来和你论文中画的结构图不一样，没看懂。