BGU-CS-VIL / WTConv

Wavelet Convolutions for Large Receptive Fields. ECCV 2024.
MIT License

About how to use #4

Closed zhaocan22 closed 3 months ago

zhaocan22 commented 3 months ago

Sorry to bother you. Would you mind telling me how to use WTConv in my model? (I use a U-Net as my backbone.)

shahaffind commented 3 months ago

No problem, if your backbone uses depth-wise convolutions you can switch them directly with WTConv (see wtconvnext.py line 103). If your backbone uses only dense convolutions, you can replace them with separable convs, for example:

conv = nn.Conv2d(in_c, out_c, kernel_size=3)

can be replaced with:

conv_dw = nn.Conv2d(in_c, in_c, kernel_size=3, groups=in_c)  # depth-wise 3x3 (one group per input channel)
conv_pw = nn.Conv2d(in_c, out_c, kernel_size=1)  # point-wise 1x1
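
As a rough illustration of why the separable form is cheaper than the dense convolution, the parameter counts can be compared with plain arithmetic. This is only a sketch: the helper conv2d_params and the channel counts 64 and 128 are made up for the example, and it counts the plain nn.Conv2d depth-wise layer, not WTConv (which adds its own per-level weights).

```python
def conv2d_params(in_c, out_c, kernel_size, groups=1, bias=True):
    """Learnable parameters of a 2D conv: weights plus optional bias."""
    weights = (in_c // groups) * out_c * kernel_size * kernel_size
    return weights + (out_c if bias else 0)


in_c, out_c = 64, 128  # hypothetical channel counts, chosen only for illustration

# Dense 3x3 convolution: every output channel sees every input channel.
dense = conv2d_params(in_c, out_c, kernel_size=3)

# Separable replacement: depth-wise 3x3 (groups=in_c) followed by point-wise 1x1.
depthwise = conv2d_params(in_c, in_c, kernel_size=3, groups=in_c)
pointwise = conv2d_params(in_c, out_c, kernel_size=1)
separable = depthwise + pointwise

print("dense:", dense)
print("separable:", separable)
print("ratio: %.1fx" % (dense / separable))
```

The gap grows with the number of channels, since the dense layer scales with in_c * out_c * k * k while the separable pair scales with in_c * k * k + in_c * out_c.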

The results should be quite similar to those of the original network, and the separable version should be more efficient.
Now you can replace the depth-wise conv with WTConv:

conv_dw = WTConv2d(in_c, in_c, kernel_size=3, wt_levels=3)  # WTConv with 3x3 kernel and 3 levels
conv_pw = nn.Conv2d(in_c, out_c, kernel_size=1)  # point-wise 1x1

zhaocan22 commented 3 months ago

I really appreciate your kind help.

shahaffind commented 3 months ago

I'm glad I can help :)

zhaocan22 commented 3 months ago

Hi, I am very interested in your work. How can conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1) be replaced by your WTConv if I want to change the size of the feature maps?

shahaffind commented 3 months ago

WTConv supports strides, so you can use it directly. Instead of

conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1)

You can use

conv = nn.Sequential(
    WTConv2d(in_c, in_c, kernel_size=3, stride=2, wt_levels=2),  # strided depth-wise WTConv
    nn.Conv2d(in_c, out_c, kernel_size=1)  # point-wise 1x1 to change the channel count
)

(and modify wt_levels as you like)
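
For the feature-map size question, the spatial size produced by the original strided nn.Conv2d can be worked out with the standard output-size formula from the PyTorch Conv2d documentation. The sketch below is plain Python and the helper name conv_out_size is made up for the example. That the WTConv2d replacement with stride=2 gives the same size is an assumption here, not something this thread confirms, so it is worth passing a dummy tensor through both versions and comparing shapes.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension of a standard 2D convolution."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Original layer from the question: kernel_size=3, stride=2, padding=1.
for h in (64, 65, 224):  # hypothetical input heights
    out = conv_out_size(h, kernel_size=3, stride=2, padding=1)
    print(h, "->", out)
```

With kernel_size=3, stride=2 and padding=1 the output is the input size halved and rounded up, which is the size the replacement block needs to reproduce for the rest of the network to stay unchanged.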