TPAMI 2024: Frequency-aware Feature Fusion for Dense Image Prediction
The preliminary official implementation of our TPAMI 2024 paper "Frequency-aware Feature Fusion for Dense Image Prediction", which is also available at https://github.com/ying-fu/FreqFusion.
Interested readers are also referred to an insightful [Note]() about this work on Zhihu (TODO).
[TOC]
Dense image prediction tasks demand features with strong category information and precise spatial boundary details at high resolution. To achieve this, modern hierarchical models often utilize feature fusion, directly adding upsampled coarse features from deep layers and high-resolution features from lower levels. In this paper, we observe rapid variations in fused feature values within objects, resulting in intra-category inconsistency due to disturbed high-frequency features. Additionally, blurred boundaries in fused features lack accurate high frequency, leading to boundary displacement. Building upon these observations, we propose Frequency-Aware Feature Fusion (FreqFusion), integrating an Adaptive Low-Pass Filter (ALPF) generator, an offset generator, and an Adaptive High-Pass Filter (AHPF) generator. The ALPF generator predicts spatially-variant low-pass filters to attenuate high-frequency components within objects, reducing intra-class inconsistency during upsampling. The offset generator refines large inconsistent features and thin boundaries by replacing inconsistent features with more consistent ones through resampling, while the AHPF generator enhances high-frequency detailed boundary information lost during downsampling. Comprehensive visualization and quantitative analysis demonstrate that FreqFusion effectively improves feature consistency and sharpens object boundaries. Extensive experiments across various dense prediction tasks confirm its effectiveness.
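To make the idea above concrete, below is a heavily simplified, illustrative sketch of the frequency-aware fusion concept. It is not the official module: the class name `ToyFreqAwareFusion` is made up, a fixed box blur stands in for the learned spatially-variant filters, sigmoid gates stand in for the ALPF/AHPF generators, and the offset generator is omitted entirely.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyFreqAwareFusion(nn.Module):
    """Illustrative only: spatially gated low-pass on lr_feat, high-pass residual on hr_feat."""
    def __init__(self, channels):
        super().__init__()
        # predict per-pixel gates in [0, 1] from the concatenated features
        self.lp_gate = nn.Sequential(nn.Conv2d(2 * channels, 1, 3, padding=1), nn.Sigmoid())
        self.hp_gate = nn.Sequential(nn.Conv2d(2 * channels, 1, 3, padding=1), nn.Sigmoid())

    @staticmethod
    def box_blur(x, k=3):
        # a fixed low-pass filter; the real ALPF/AHPF generators predict spatially-variant kernels
        return F.avg_pool2d(x, kernel_size=k, stride=1, padding=k // 2)

    def forward(self, hr_feat, lr_feat):
        lr_up = F.interpolate(lr_feat, size=hr_feat.shape[-2:], mode='bilinear', align_corners=False)
        guide = torch.cat([hr_feat, lr_up], dim=1)
        # low-pass: smooth the upsampled lr features where the gate is high (inside objects)
        g_lp = self.lp_gate(guide)
        lr_smooth = g_lp * self.box_blur(lr_up) + (1 - g_lp) * lr_up
        # high-pass: add back detail (feature minus its blur) where the gate is high (near boundaries)
        hr_detail = hr_feat - self.box_blur(hr_feat)
        return hr_feat + lr_smooth + self.hp_gate(guide) * hr_detail
```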
The clean code for FreqFusion is available here. By utilizing their frequency properties, FreqFusion is capable of enhancing the quality of both low- and high-resolution features (referred to as `lr_feat` and `hr_feat`, respectively, with the assumption that the spatial size of `hr_feat` is twice that of `lr_feat`). The usage is very simple:
```python
import torch
from FreqFusion import FreqFusion  # assuming FreqFusion.py from this repo is on your path

ff = FreqFusion(hr_channels=64, lr_channels=64)
hr_feat = torch.rand(1, 64, 32, 32)
lr_feat = torch.rand(1, 64, 16, 16)
_, hr_feat, lr_feat = ff(hr_feat=hr_feat, lr_feat=lr_feat)  # lr_feat is now [1, 64, 32, 32]
```
Where should I integrate FreqFusion?
You should integrate FreqFusion wherever you need to perform upsampling. FreqFusion fully utilizes both low- and high-resolution features: it very effectively recovers high-resolution, semantically accurate features from low-resolution, high-level features, while enhancing the details of high-resolution, low-level features.
Example of the concat version for feature fusion (SegNeXt, SegFormer):
You can refer to ham_head.py.
```python
x1, x2, x3, x4 = backbone(img)  # x1, x2, x3, x4 at 1/4, 1/8, 1/16, 1/32 of the input size
x1, x2, x3, x4 = conv1x1(x1), conv1x1(x2), conv1x1(x3), conv1x1(x4)  # 1x1 convs to align channels to c
ff1 = FreqFusion(hr_channels=c, lr_channels=c)
ff2 = FreqFusion(hr_channels=c, lr_channels=2 * c)
ff3 = FreqFusion(hr_channels=c, lr_channels=3 * c)
_, x3, x4_up = ff1(hr_feat=x3, lr_feat=x4)
_, x2, x34_up = ff2(hr_feat=x2, lr_feat=torch.cat([x3, x4_up], dim=1))
_, x1, x234_up = ff3(hr_feat=x1, lr_feat=torch.cat([x2, x34_up], dim=1))
x1234 = torch.cat([x1, x234_up], dim=1)  # channel = 4c, 1/4 of the input size
```
Another example of the concat version for feature fusion (you may try it for UNet-like architectures):
```python
x1, x2, x3, x4 = backbone(img)  # x1, x2, x3, x4 at 1/4, 1/8, 1/16, 1/32 of the input size
x1, x2, x3, x4 = conv1x1(x1), conv1x1(x2), conv1x1(x3), conv1x1(x4)  # 1x1 convs to align channels to c
ff1 = FreqFusion(hr_channels=c, lr_channels=c)
ff2 = FreqFusion(hr_channels=c, lr_channels=c)
ff3 = FreqFusion(hr_channels=c, lr_channels=c)
y4 = x4  # channel = c
_, x3, y4_up = ff1(hr_feat=x3, lr_feat=y4)
y3 = conv(torch.cat([x3, y4_up], dim=1))  # conv reduces 2c -> c
_, x2, y3_up = ff2(hr_feat=x2, lr_feat=y3)
y2 = conv(torch.cat([x2, y3_up], dim=1))  # conv reduces 2c -> c
_, x1, y2_up = ff3(hr_feat=x1, lr_feat=y2)
y1 = conv(torch.cat([x1, y2_up], dim=1))  # conv reduces 2c -> c
```
Example of the add version for feature fusion (FPN-based methods):
You can refer to FPN.py.
```python
x1, x2, x3, x4 = backbone(img)  # x1, x2, x3, x4 at 1/4, 1/8, 1/16, 1/32 of the input size
x1, x2, x3, x4 = conv1x1(x1), conv1x1(x2), conv1x1(x3), conv1x1(x4)  # conv1x1s in the original FPN to align channels to c
ff1 = FreqFusion(hr_channels=c, lr_channels=c)
ff2 = FreqFusion(hr_channels=c, lr_channels=c)
ff3 = FreqFusion(hr_channels=c, lr_channels=c)
y4 = x4
_, x3, y4_up = ff1(hr_feat=x3, lr_feat=y4)
y3 = x3 + y4_up
_, x2, y3_up = ff2(hr_feat=x2, lr_feat=y3)
y2 = x2 + y3_up
_, x1, y2_up = ff3(hr_feat=x1, lr_feat=y2)
y1 = x1 + y2_up
```
FreqFusion relies on the mmcv library. You can install mmcv-full with:
```shell
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.5.3 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11/index.html
```
You can refer to https://mmcv.readthedocs.io/en/v1.7.0/get_started/installation.html and select the appropriate installation command for your operating system, CUDA version, PyTorch version, and MMCV version.
Tips:
MMCV installation may be annoying. Although the adaptive low-/high-pass filtering in FreqFusion can fall back to `torch.nn.functional.unfold` as a replacement (you can try it), this may consume a large amount of GPU memory. Therefore, I suggest using MMCV for efficiency.
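For reference, here is a minimal sketch of what such an unfold-based fallback could look like, assuming a generator produces one softmax-normalized k×k kernel per output pixel (the helper name `apply_pixelwise_kernel` is illustrative, not the repository's API). The dedicated mmcv op avoids materializing the unfolded patch tensor, which is exactly where the memory cost mentioned above comes from.

```python
import torch
import torch.nn.functional as F

def apply_pixelwise_kernel(feat, kernels, k=3):
    """Apply a spatially-variant k x k kernel at every pixel via unfold.

    feat:    (B, C, H, W) feature map to be filtered
    kernels: (B, k*k, H, W) per-pixel kernels (e.g. softmax-normalized for a low-pass filter)
    Note: unfold materializes a (B, C*k*k, H*W) tensor, hence the large memory usage.
    """
    B, C, H, W = feat.shape
    patches = F.unfold(feat, kernel_size=k, padding=k // 2)   # (B, C*k*k, H*W)
    patches = patches.view(B, C, k * k, H, W)
    return (patches * kernels.unsqueeze(1)).sum(dim=2)        # (B, C, H, W)

# toy usage
feat = torch.rand(1, 64, 32, 32)
kernels = torch.softmax(torch.rand(1, 9, 32, 32), dim=1)      # one 3x3 low-pass kernel per pixel
out = apply_pixelwise_kernel(feat, kernels, k=3)              # (1, 64, 32, 32)
```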
Code of SegNeXt
Core modification:
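The actual change lives in ham_head.py in this repo. As a rough, hypothetical sketch of the idea (variable and attribute names such as `self.freq_fusion1` are illustrative, not the verified code), the bilinear resize + concat of the multi-scale decode-head inputs is replaced by a chain of FreqFusion modules, following the concat example above:

```python
# x2, x3, x4: decode-head inputs at 1/8, 1/16, 1/32 (after 1x1 convs to c channels)
# before: inputs = [resize(x, size=x2.shape[2:], mode='bilinear') for x in (x2, x3, x4)]; x = torch.cat(inputs, dim=1)
# after (hypothetical sketch):
_, x3, x4_up = self.freq_fusion1(hr_feat=x3, lr_feat=x4)
_, x2, x34_up = self.freq_fusion2(hr_feat=x2, lr_feat=torch.cat([x3, x4_up], dim=1))
x = torch.cat([x2, x34_up], dim=1)  # fused decode-head input
```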
Method | Backbone | Crop Size | Lr Schd | mIoU |
---|---|---|---|---|
SegNeXt | MSCAN-T | 512x512 | 160k | 41.1 |
SegNeXt + FreqFusion | MSCAN-T | 512x512 | 160k | 43.5 |
Checkpoint:
Method | Backbone | mIoU | Configs | Links |
---|---|---|---|---|
SegNeXt + FreqFusion | MSCAN-T | 43.7 (43.5 in paper) | config | ckpt (code: PAMI) |
Note:
The original SegNeXt code can be found here.
Our code is based on MMSegmentation. You can install mmseg by:
```shell
pip install mmsegmentation==0.24.1
```
Please refer to get_started.md for more details on installation, and dataset_prepare.md for information on dataset preparation. For further details on code usage, you can refer to this.
You can install mmcv-full by:
```shell
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.5.3 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11/index.html
```
For more details on installing and using SegNeXt, please refer to the README file.
Code of Mask2Former
Core modification:
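The exact edit is in this repo's Mask2Former code. As a rough, hypothetical sketch of the idea (names such as `self.freq_fusion`, `lateral_conv`, and `output_conv` follow a typical FPN-style pixel decoder and are not verified against the actual diff), the bilinear upsampling in the top-down path of the pixel decoder is replaced by FreqFusion, following the add example above:

```python
# FPN-style top-down step in the pixel decoder (simplified):
# before:
#   y = lateral_conv(x) + F.interpolate(y, size=x.shape[-2:], mode='bilinear', align_corners=False)
# after (hypothetical): let FreqFusion upsample the coarse map and enhance the lateral map
cur_feat = lateral_conv(x)
_, cur_feat, y_up = self.freq_fusion(hr_feat=cur_feat, lr_feat=y)
y = output_conv(cur_feat + y_up)
```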
Mask2Former | Backbone | mIoU |
---|---|---|
Bilinear | Swin-B | 53.9 |
FreqFusion (Ours) | Swin-B | 55.3 (+1.4) |
Bilinear | Swin-L | 56.1 |
FreqFusion (Ours) | Swin-L | 56.8 (+0.7) |
Checkpoint:
Mask2Former | Backbone | mIoU | Configs | Links |
---|---|---|---|---|
FreqFusion | Swin-B | 55.7 (55.3 in paper) | config | ckpt (code: PAMI) |
FreqFusion | Swin-L | 57.0 (56.8 in paper) | config | ckpt (code: PAMI) |
Note:
Install Mask2Former.
See Preparing Datasets for Mask2Former.
See Getting Started with Mask2Former.
See installation instructions.
For more details on installing and using Mask2Former, please refer to the README file.
Code for Faster R-CNN, Mask R-CNN, Panoptic FPN: Here (mmdet==2.28.1)
Faster R-CNN (Detection) | Backbone | Box AP |
---|---|---|
Nearest | R50 | 37.5 |
Deconv | R50 | 37.3 |
PixelShuffle | R50 | 37.5 |
CARAFE | R50 | 38.6 |
IndexNet | R50 | 37.6 |
A2U | R50 | 37.3 |
FADE | R50 | 38.5 |
SAPA-B | R50 | 37.8 |
DySample-S+ | R50 | 38.6 |
DySample+ | R50 | 38.7 |
FreqFusion (Ours) | R50 | 39.4 |
Nearest | R101 | 39.4 |
DySample+ | R101 | 40.5 |
FreqFusion (Ours) | R101 | 41.0 |
Checkpoint:
Faster R-CNN | Backbone | Box AP | Configs | Links |
---|---|---|---|---|
FreqFusion | ResNet-50 | 39.5 (39.4 in paper) | config | ckpt (code: PAMI) |
FreqFusion | ResNet-101 | 41.1 (41.0 in paper) | config | ckpt (code: PAMI) |
Mask R-CNN | Backbone | Box AP | Mask AP |
---|---|---|---|
Nearest | R50 | 38.3 | 34.7 |
Deconv | R50 | 37.9 | 34.5 |
PixelShuffle | R50 | 38.5 | 34.8 |
CARAFE | R50 | 39.2 | 35.4 |
IndexNet | R50 | 38.4 | 34.7 |
A2U | R50 | 38.2 | 34.6 |
FADE | R50 | 39.1 | 35.1 |
SAPA-B | R50 | 38.7 | 35.1 |
DySample-S+ | R50 | 39.3 | 35.5 |
DySample+ | R50 | 39.6 | 35.7 |
FreqFusion (Ours) | R50 | 40.0 | 36.0 |
Nearest | R101 | 40.0 | 36.0 |
DySample+ | R101 | 41.0 | 36.8 |
FreqFusion (Ours) | R101 | 41.6 | 37.4 |
Checkpoint:
Mask R-CNN | Backbone | Mask AP | Configs | Links |
---|---|---|---|---|
FreqFusion | ResNet-50 | 36.0 | config | ckpt (code: PAMI) |
FreqFusion | ResNet-101 | 37.3 | config | ckpt (code: PAMI) |
Panoptic FPN | Backbone | Params (M) | PQ | PQth | PQst | SQ | RQ |
---|---|---|---|---|---|---|---|
Nearest | R50 | 46.0 | 40.2 | 47.8 | 28.9 | 77.8 | 49.3 |
Deconv | R50 | +1.8 | 39.6 | 47.0 | 28.4 | 77.1 | 48.5 |
PixelShuffle | R50 | +7.1 | 40.0 | 47.4 | 28.8 | 77.1 | 49.1 |
CARAFE | R50 | +0.2 | 40.8 | 47.7 | 30.4 | 78.2 | 50.0 |
IndexNet | R50 | +6.3 | 40.2 | 47.6 | 28.9 | 77.1 | 49.3 |
A2U | R50 | +29.2K | 40.1 | 47.6 | 28.7 | 77.3 | 48.0 |
FADE | R50 | +0.1 | 40.9 | 48.0 | 30.3 | 78.1 | 50.1 |
SAPA-B | R50 | +0.1 | 40.6 | 47.7 | 29.8 | 78.0 | 49.6 |
DySample-S+ | R50 | +6.2K | 41.1 | 48.1 | 30.5 | 78.2 | 50.2 |
DySample+ | R50 | +49.2K | 41.5 | 48.5 | 30.8 | 78.3 | 50.7 |
FreqFusion (Ours) | R50 | +0.3 | 42.7 | 49.3 | 32.7 | 79.0 | 51.9 |
Nearest | R101 | 65.0 | 42.2 | 50.1 | 30.3 | 78.3 | 51.4 |
DySample+ | R101 | +49.2K | 43.0 | 50.2 | 32.1 | 78.6 | 52.4 |
FreqFusion (Ours) | R101 | +0.3 | 44.0 | 50.8 | 33.7 | 79.4 | 53.4 |
Checkpoint:
Panoptic FPN | Backbone | PQ | Configs | Links |
---|---|---|---|---|
FreqFusion | ResNet-50 | 42.7 | config | ckpt (code: PAMI) |
FreqFusion | ResNet-101 | 44.1 (44.0 in paper) | config | ckpt (code: PAMI) |
Note:
Original code of mmdet can be found here.
For more details on installing and using mmdetection, please refer to the README file.
If you use our dataset or code for research, please cite this paper (early access now):
```bibtex
@ARTICLE{2024freqfusion,
  author={Chen, Linwei and Fu, Ying and Gu, Lin and Yan, Chenggang and Harada, Tatsuya and Huang, Gao},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  title={Frequency-aware Feature Fusion for Dense Image Prediction},
  year={2024},
  volume={1},
  number={1},
  pages={1-18},
  doi={10.1109/TPAMI.2024.3449959}}
```
This code is built on the mmsegmentation, Mask2Former, and mmdetection libraries.
If you encounter any problems or bugs, please don't hesitate to contact me at chenlinwei@bit.edu.cn. To ensure effective assistance, please provide a brief self-introduction, including your name, affiliation, and position. If you would like more in-depth help, feel free to provide additional information such as your personal website link. I would be happy to discuss with you and offer support.