Official implementation of "ExpoMamba: Exploiting Frequency SSM Blocks for Efficient and Effective Image Enhancement" (accepted at the ICML 2024 ES-FoMo-II workshop).
# NOTE(review): this is the body of `forward(self, x)`; its `def` line was
# pasted out of order further down. Code is unchanged here -- comments only.
# Compute the 2-D FFT of the input (torch.fft.fft2 transforms the last two dims).
fft_x = torch.fft.fft2(x)
# NOTE(review): torch.real / torch.imag return the real and imaginary parts of
# the spectrum, not its magnitude and phase (those would be torch.abs /
# torch.angle). The names `amp` / `pha` are kept as in the original.
amp = torch.real(fft_x)
pha = torch.imag(fft_x)
# Apply Dynamic Amplitude Scaling and Phase Continuity
amp_scaled = self.amplitude_scaling(amp)
pha_continuous = self.phase_continuity(pha)
# Processing with complex convolution
complex_input = torch.complex(amp_scaled, pha_continuous)
complex_processed = self.complex_conv(complex_input)
# Separate processed amplitude and phase
# NOTE(review): these two values are never read -- both names are reassigned
# by the Mamba calls immediately below -- so the complex_conv result does not
# reach the output. If it is meant to, the Mamba models would take
# `processed_amp` / `processed_pha` instead of `amp_scaled` / `pha_continuous`.
# Confirm intent against the paper before changing: doing so alters the
# model's outputs, so weights trained with the current graph may not match.
processed_amp = torch.real(complex_processed)
processed_pha = torch.imag(complex_processed)
# Process amplitude and phase with Mamba models. As written, their inputs come
# from the scaling/continuity step, not from complex_conv (see note above).
processed_amp = self.model_amp(amp_scaled)
processed_pha = self.model_pha(pha_continuous)
# Combine processed amplitude and phase, and apply inverse FFT.
# Only the real part of the inverse transform is kept; the imaginary part
# is discarded.
combined_fft = torch.complex(processed_amp, processed_pha)
output = torch.fft.ifft2(combined_fft).real
# Apply final smoothing convolution
output = self.smooth(output)
# Applying HDR processing after frequency modulation
# NOTE(review): the result is assigned to `x`, which is not used after this
# line; `output` is returned without passing through hdr_layer. Presumably
# `output = self.hdr_layer(output)` was intended -- confirm before changing.
x = self.hdr_layer(x)
return output
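In the `# Process amplitude and phase with Mamba models` step, why is it `processed_amp = self.model_amp(amp_scaled)` rather than `processed_amp = self.model_amp(processed_amp)`? As written, the `processed_amp` / `processed_pha` values produced from `complex_conv` are overwritten before they are ever used, so that convolution has no effect on the output.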
def forward(self, x):
# Compute FFT to get amplitude and phase
In the `# Process amplitude and phase with Mamba models` step, why is it `processed_amp = self.model_amp(amp_scaled)` rather than `processed_amp = self.model_amp(processed_amp)`?