YanchaoYang / FDA

Fourier Domain Adaptation for Semantic Segmentation

Error in the FDA implementation #40

Open LucasFidon opened 2 years ago

LucasFidon commented 2 years ago

Hi, thank you for sharing your code!

I have noticed two errors in your implementation of FDA in https://github.com/YanchaoYang/FDA/blob/b9a0cdf0bcc9b787c00e39df73eda5673706f219/utils/__init__.py#L11-L18

  1. Only the top-left and bottom-left parts of the image in Fourier space need to be changed. This is because you use rfft (and not fft), so only half of the spectrum is returned (the other half is redundant by symmetry). Otherwise, you are also replacing high frequencies, which is not what the paper describes.

  2. The indices of the mutated region should be symmetric around 0. Otherwise the output after the inverse FFT will have a non-zero imaginary part (-b should be -b+1 above).

Also, you may want to multiply b by 0.5 so that the maximum transform is reached at L=1 instead of L=0.5 (see https://github.com/YanchaoYang/FDA/issues/36#issue-1031310669).
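To make the index argument in points 1 and 2 concrete, here is a minimal sketch using NumPy's FFT, which lays out frequencies the same way as torch.fft. The helper name low_freq_mask and the sizes h = w = 8, b = 2 are illustrative assumptions, not part of the repository.

```python
import numpy as np

def low_freq_mask(h, w, b):
    """Entries of the rfft2 half-spectrum of an h x w image with |fy| < b and fx < b."""
    fy = np.abs(np.rint(np.fft.fftfreq(h) * h))  # row k holds frequency k or k - h
    fx = np.rint(np.fft.rfftfreq(w) * w)         # columns hold 0 .. w // 2 only
    return (fy[:, None] < b) & (fx[None, :] < b)

h, w, b = 8, 8, 2
half = np.fft.rfft2(np.zeros((h, w)))
print(half.shape)  # (8, 5): only half of the columns are returned

# Region selected by the slices 0:b and h-b+1:h on the rows, 0:b on the columns
slice_mask = np.zeros(half.shape, dtype=bool)
slice_mask[0:b, 0:b] = True          # top left
slice_mask[h - b + 1:h, 0:b] = True  # bottom left
print(np.array_equal(slice_mask, low_freq_mask(h, w, b)))  # True

# With h-b instead of h-b+1 the rows cover an asymmetric set of frequencies
rows = np.r_[0:b, h - b:h]
row_freqs = np.rint(np.fft.fftfreq(h) * h)[rows].astype(int).tolist()
print(row_freqs)  # [0, 1, -2, -1]
```

With h-b+1 the rows cover the frequencies -1, 0, 1, which is symmetric around 0; with h-b the extra row at frequency -2 has no counterpart at +2.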

I propose to tackle those issues with this modified version of your PyTorch implementation:

import numpy as np
import torch

def extract_ampl_phase(fft_im):
    # fft_im: size should be b x 3 x h x w
    fft_amp = torch.abs(fft_im)
    fft_pha = torch.angle(fft_im)
    return fft_amp, fft_pha

def low_freq_mutate(amp_src, amp_trg, L=0.1):
    _, _, h, w = amp_src.size()
    # multiply w by 2 because we have only half the space as rFFT is used
    w *= 2
    # multiply by 0.5 to have the maximum b for L=1 like in the paper
    b = (np.floor(0.5 * np.amin((h, w)) * L)).astype(int)     # get b
    if b > 0:
        # When rFFT is used only half of the space needs to be updated
        # because of the symmetry along the last dimension
        amp_src[:, :, 0:b, 0:b] = amp_trg[:, :, 0:b, 0:b]      # top left
        amp_src[:, :, h-b+1:h, 0:b] = amp_trg[:, :, h-b+1:h, 0:b]    # bottom left
    return amp_src

def FDA_source_to_target(src_img, trg_img, L=0.1):
    # get fft of both source and target
    fft_src = torch.fft.rfft2(src_img.clone(), dim=(-2, -1))
    fft_trg = torch.fft.rfft2(trg_img.clone(), dim=(-2, -1))

    # extract amplitude and phase of both ffts
    amp_src, pha_src = extract_ampl_phase(fft_src.clone())
    amp_trg, pha_trg = extract_ampl_phase(fft_trg.clone())

    # replace the low frequency amplitude part of source with that from target
    amp_src_ = low_freq_mutate(amp_src.clone(), amp_trg.clone(), L=L)

    # recompose fft of source
    real = torch.cos(pha_src.clone()) * amp_src_.clone()
    imag = torch.sin(pha_src.clone()) * amp_src_.clone()
    fft_src_ = torch.complex(real=real, imag=imag)

    # get the recomposed image: source content, target style
    _, _, imgH, imgW = src_img.size()
    src_in_trg = torch.fft.irfft2(fft_src_, dim=(-2, -1), s=[imgH, imgW])

    return src_in_trg

I have tested it with PyTorch v1.11.0.

Best wishes, Lucas

cong-code commented 9 months ago

Hi, if I want to extend this to 3D images, how should I do that? Should I use torch.fft.rfftn instead?
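One possible way to approach the 3D case is sketched below in NumPy. This is not from the repository: the function name fda_source_to_target_3d and the choice of a cubic low-frequency block are assumptions for illustration. np.fft.rfftn and np.fft.irfftn are used here; torch.fft.rfftn and torch.fft.irfftn take analogous dim and s arguments. Only the last axis is halved by the real FFT, so the first two axes need both the positive and the negative low frequencies.

```python
import numpy as np

def fda_source_to_target_3d(src, trg, L=0.1):
    """Hypothetical 3D variant. src, trg: real arrays of identical shape (d, h, w)."""
    d, h, w = src.shape
    axes = (-3, -2, -1)
    fft_src = np.fft.rfftn(src, axes=axes)
    fft_trg = np.fft.rfftn(trg, axes=axes)
    amp_src, pha_src = np.abs(fft_src), np.angle(fft_src)
    amp_trg = np.abs(fft_trg)

    # same 0.5 factor as in the 2D proposal, so that L=1 gives the largest block
    b = int(np.floor(0.5 * min(d, h, w) * L))
    if b > 0:
        # full axes: frequencies 0..b-1 and -(b-1)..-1; halved last axis: 0..b-1
        d_idx = np.r_[0:b, d - b + 1:d]
        h_idx = np.r_[0:b, h - b + 1:h]
        w_idx = np.arange(b)
        block = np.ix_(d_idx, h_idx, w_idx)
        amp_src[block] = amp_trg[block]

    # recompose the spectrum from the mutated amplitude and the source phase
    fft_new = amp_src * np.exp(1j * pha_src)
    return np.fft.irfftn(fft_new, s=(d, h, w), axes=axes)

rng = np.random.default_rng(0)
src = rng.random((4, 6, 8))
trg = rng.random((4, 6, 8))
out = fda_source_to_target_3d(src, trg, L=0.5)
print(out.shape)  # (4, 6, 8)
```

With L=0 nothing is swapped and the source volume is recovered up to floating-point error, which is a quick sanity check for the round trip.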

xingshuojing commented 3 months ago

I think it should be calculated for all four corners, because the author used the rfft from an older version of PyTorch, and the arguments passed to it indicate that the onesided option was not applied.
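The difference between the two layouts can be checked with a small NumPy sketch. It assumes, as this comment suggests, that the original code worked on a two-sided spectrum; np.fft.fft2 stands in for that here, and the sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((8, 8))
h, w = img.shape

full = np.fft.fft2(img)   # two-sided spectrum, shape (8, 8)
half = np.fft.rfft2(img)  # one-sided spectrum, shape (8, 5)

# The one-sided spectrum is the left part of the two-sided one
print(np.allclose(full[:, : w // 2 + 1], half))  # True

# Low frequencies of the two-sided spectrum: |fy| < b and |fx| < b
b = 2
fy = np.abs(np.rint(np.fft.fftfreq(h) * h))
fx = np.abs(np.rint(np.fft.fftfreq(w) * w))
low = (fy[:, None] < b) & (fx[None, :] < b)

# They sit in all four corners of the unshifted spectrum
corners = np.zeros((h, w), dtype=bool)
for rows in (slice(0, b), slice(h - b + 1, h)):
    for cols in (slice(0, b), slice(w - b + 1, w)):
        corners[rows, cols] = True
print(np.array_equal(low, corners))  # True
```

So both views can be consistent: four corners for a two-sided spectrum, and only the two left corners once a one-sided transform such as torch.fft.rfft2 is used.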