Linwei-Chen / FreqFusion

TPAMI: Frequency-aware Feature Fusion for Dense Image Prediction

Can't understand the code for the Enhancing Initial Fusion part #15

Open yx-yyds opened 1 month ago

yx-yyds commented 1 month ago

    if self.semi_conv:
        if self.comp_feat_upsample:
            if self.use_high_pass:
                mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat)
                mask_hr_init = self.kernel_normalizer(mask_hr_hr_feat, self.highpass_kernel, hamming=self.hamming_highpass)
                compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)

                mask_lr_hr_feat = self.content_encoder(compressed_hr_feat)
                mask_lr_init = self.kernel_normalizer(mask_lr_hr_feat, self.lowpass_kernel, hamming=self.hamming_lowpass)

                mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat)
                mask_lr_lr_feat = F.interpolate(
                    carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest')
                mask_lr = mask_lr_hr_feat + mask_lr_lr_feat

                mask_lr_init = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
                mask_hr_lr_feat = F.interpolate(
                    carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest')
                mask_hr = mask_hr_hr_feat + mask_hr_lr_feat
            else: raise NotImplementedError
        else:
            mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(self.content_encoder(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
            if self.use_high_pass:
                mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(self.content_encoder2(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
    else:
        compressed_x = F.interpolate(compressed_lr_feat, size=compressed_hr_feat.shape[-2:], mode='nearest') + compressed_hr_feat
        mask_lr = self.content_encoder(compressed_x)
        if self.use_high_pass: 
            mask_hr = self.content_encoder2(compressed_x)

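Hi, could you add comments to this part of the code? It looks different from the architecture diagram in your paper, and I can't follow it.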

Linwei-Chen commented 1 month ago

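Thanks for your interest! The idea is the same as in the paper; the implementation is just a bit convoluted, and you are welcome to try simplifying it. I have updated the comments for your reference: https://github.com/Linwei-Chen/FreqFusion/blob/main/FreqFusion.py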
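One way to read the `use_high_pass` line in the snippet above, `compressed_hr_feat + compressed_hr_feat - carafe(...)`, is as `x + (x - smooth(x))`: the feature plus its own high-frequency residual. The sketch below illustrates only that arithmetic, on a 1-D signal in plain Python. It assumes the kernel handed to `carafe` is non-negative and sums to 1, so that it behaves like a low-pass filter. The fixed 3-tap kernel, the border handling, and the helper names `smooth` and `enhance_high_freq` are hypothetical stand-ins, not the repository's code, which predicts a separate kernel for every location.

```python
def smooth(signal, kernel):
    """Low-pass filter a 1-D list with a normalized kernel.

    Hypothetical stand-in for carafe(..., scale=1); borders are handled
    by clamping the index, which is an assumption of this sketch.
    """
    radius = len(kernel) // 2
    last = len(signal) - 1
    out = []
    for i in range(len(signal)):
        acc = 0.0
        for k, weight in enumerate(kernel):
            j = min(max(i + k - radius, 0), last)  # clamp index at the borders
            acc += weight * signal[j]
        out.append(acc)
    return out


def enhance_high_freq(signal, kernel):
    """Return x + (x - smooth(x)), i.e. the signal plus its high-pass residual."""
    low = smooth(signal, kernel)
    return [x + x - l for x, l in zip(signal, low)]


kernel = [0.25, 0.5, 0.25]   # sums to 1, so it only averages neighbours
flat = [1.0, 1.0, 1.0, 1.0]  # no high frequencies: should be left unchanged
edge = [0.0, 0.0, 1.0, 1.0]  # a step: should come out sharpened

print(enhance_high_freq(flat, kernel))  # [1.0, 1.0, 1.0, 1.0]
print(enhance_high_freq(edge, kernel))  # [0.0, -0.25, 1.25, 1.0]
```

A flat region passes through untouched, while the step gains an undershoot and overshoot on either side, which is the sharpening effect one would expect from adding a high-pass residual.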

yx-yyds commented 1 month ago

OK, thanks.

wzhwantmoney commented 1 month ago

Which variable in the code corresponds to Z in the paper's figure?

Linwei-Chen commented 1 month ago

> Which variable in the code corresponds to Z in the paper's figure?

Thanks for your interest. `compressed_x` can be treated as Z.
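To make that correspondence concrete: in the last branch of the snippet, `compressed_x` is the low-resolution feature upsampled with nearest-neighbor interpolation and added to the high-resolution feature. The plain-Python sketch below mirrors that arithmetic on single-channel 2-D lists. The helper names `upsample_nearest` and `initial_fusion` are hypothetical, and the index rule used here (source index = floor(dst * in / out)) is an assumption about how `mode='nearest'` picks pixels, not something confirmed in this thread.

```python
def upsample_nearest(feat, out_h, out_w):
    """Nearest-neighbor resize of a 2-D list to (out_h, out_w).

    Assumed index rule: source = floor(dst * in_size / out_size).
    """
    in_h, in_w = len(feat), len(feat[0])
    return [
        [feat[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


def initial_fusion(lr_feat, hr_feat):
    """Z = upsample(lr_feat) + hr_feat, computed element by element."""
    out_h, out_w = len(hr_feat), len(hr_feat[0])
    upsampled = upsample_nearest(lr_feat, out_h, out_w)
    return [
        [u + h for u, h in zip(up_row, hr_row)]
        for up_row, hr_row in zip(upsampled, hr_feat)
    ]


lr = [[1, 2],
      [3, 4]]                      # 2x2 low-resolution feature
hr = [[10] * 4 for _ in range(4)]  # 4x4 high-resolution feature

z = initial_fusion(lr, hr)
for row in z:
    print(row)
# [11, 11, 12, 12]
# [11, 11, 12, 12]
# [13, 13, 14, 14]
# [13, 13, 14, 14]
```

In the snippet, this fused map is then passed to `content_encoder` (and to `content_encoder2` when `use_high_pass` is set) to produce `mask_lr` and `mask_hr`.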