Open nkraicer opened 2 months ago
Hey @nkraicer, thanks for your interest! We didn't develop SpectralDiffuserCam, so I'm not completely familiar with their implementation.
How many channels does your PSF have?
The one place where I think RealFFTConvolve2D would have to be changed is here, as the package expects either an RGB or grayscale PSF. This should be an easy fix: use psf.shape[-1] instead of the hardcoded 3. There are some other checks that need to be relaxed to accept a channel count other than 3 (here, here).
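As a rough illustration of reading the channel count from the PSF rather than hardcoding 3, here is a minimal NumPy sketch. The helper names `n_psf_channels` and `convolve_per_channel` are hypothetical and are not part of the package; the sketch assumes a PSF and image of shape (H, W, C) and uses circular convolution for brevity.

```python
import numpy as np


def n_psf_channels(psf):
    """Hypothetical helper: channel count of a PSF shaped (H, W, C)."""
    if psf.ndim != 3:
        raise ValueError(f"expected PSF of shape (H, W, C), got {psf.shape}")
    # Taken from the array itself, so any C >= 1 is accepted (not just 1 or 3).
    return psf.shape[-1]


def convolve_per_channel(image, psf):
    """Circularly convolve each image channel with the matching PSF channel."""
    n_channels = n_psf_channels(psf)
    if image.shape != psf.shape:
        raise ValueError(
            f"image {image.shape} and PSF {psf.shape} must share (H, W, {n_channels})"
        )
    spectrum = np.fft.rfft2(image, axes=(0, 1)) * np.fft.rfft2(psf, axes=(0, 1))
    return np.fft.irfft2(spectrum, s=psf.shape[:2], axes=(0, 1))
```

With a delta PSF in every channel, `convolve_per_channel` should hand back the input unchanged, which is a quick way to check that nothing downstream still assumes three channels.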
Once those are fixed so that the channel count no longer has to be 3, I would start by trying the ADMM script with your PSF and measurement to see if you get an expected reconstruction.
Let me know if that's unclear or if I missed something!
@nkraicer I took a closer look at how the original authors pose the image recovery. They don't assume a unique PSF per channel, so I guess you also have a single-channel PSF?
In that case RealFFTConvolve2D doesn't have to be modified. Instead, it's a matter of setting the appropriate forward and adjoint operators: namely multiplying with the mask and summing over channels, as in their forward, and multiplying with the mask, as in their adjoint. Then you would have to change how our FISTA computes the gradient to get something like their grads.
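To make the forward/adjoint pair concrete, here is a minimal NumPy sketch. It is not the SpectralDiffuserCam code: it assumes a single-channel PSF of shape (H, W), a mask of shape (H, W, C), and circular convolution without the pad/crop steps their implementation uses. `make_ops` is a hypothetical name.

```python
import numpy as np


def make_ops(psf, mask):
    """Return (forward, adjoint) for y = sum_c mask[..., c] * (psf conv x[..., c])."""
    shape = psf.shape
    # Transfer function of the shared PSF, with a trailing axis to broadcast over C.
    H = np.fft.rfft2(psf, axes=(0, 1))[..., None]

    def forward(x):
        # x: (H, W, C) -> convolve every channel, apply the mask, sum over channels.
        conv = np.fft.irfft2(H * np.fft.rfft2(x, axes=(0, 1)), s=shape, axes=(0, 1))
        return np.sum(mask * conv, axis=2)  # (H, W)

    def adjoint(y):
        # y: (H, W) -> apply the mask per channel, then correlate with the PSF.
        masked = mask * y[..., None]  # (H, W, C)
        spectrum = np.conj(H) * np.fft.rfft2(masked, axes=(0, 1))
        return np.fft.irfft2(spectrum, s=shape, axes=(0, 1))  # (H, W, C)

    return forward, adjoint
```

A useful sanity check before plugging operators like these into a solver is the dot-product test: for random `x` and `y`, `np.vdot(forward(x), y)` and `np.vdot(x, adjoint(y))` should agree to floating-point precision.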
So one option would be to make a new FISTAHyperSpectral class that inherits from FISTA and overrides the _grad method like so:
def _grad(self):
    # make sure to sum on correct axis, and apply mask on correct dimensions
    # keepdims=True keeps the channel axis so diff broadcasts against the mask
    diff = np.sum(self.mask * self._convolver.convolve(self._image_est), axis=2, keepdims=True) - self._data  # (H, W, 1)
    return self._convolver.deconvolve(diff * self.mask)  # (H, W, C) where C is number of hyperspectral channels
You could also replace proj with one of their methods.
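For anyone wanting to try the gradient outside the package first, below is a self-contained NumPy sketch of projected gradient descent using the same masked gradient. It does not subclass the package's FISTA class. The class name `MaskedGradientDescent`, the circular convolution (no pad/crop), the non-negativity projection, and the step-size bound are all assumptions made for illustration.

```python
import numpy as np


class MaskedGradientDescent:
    """Toy solver for min_x 0.5 * ||sum_c mask_c * (psf conv x_c) - data||^2, x >= 0."""

    def __init__(self, psf, mask, data):
        self._shape = psf.shape                              # (H, W)
        self._H = np.fft.rfft2(psf, axes=(0, 1))[..., None]  # shared PSF spectrum
        self.mask = mask                                     # (H, W, C)
        self._data = data                                    # (H, W)
        self._image_est = np.zeros(mask.shape)
        # Upper bound on the Lipschitz constant of the gradient (assumption:
        # circular convolution, so the operator norm factors as below).
        lipschitz = np.max(np.abs(self._H)) ** 2 * np.max(np.sum(mask**2, axis=2))
        self.step = 1.0 / lipschitz

    def _filter(self, x, transfer):
        spectrum = transfer * np.fft.rfft2(x, axes=(0, 1))
        return np.fft.irfft2(spectrum, s=self._shape, axes=(0, 1))

    def _residual(self):
        conv = self._filter(self._image_est, self._H)
        return np.sum(self.mask * conv, axis=2) - self._data  # (H, W)

    def _grad(self):
        # Adjoint applied to the residual: mask per channel, then correlate.
        return self._filter(self._residual()[..., None] * self.mask, np.conj(self._H))

    def loss(self):
        return 0.5 * np.sum(self._residual() ** 2)

    def update(self):
        # Gradient step followed by projection onto non-negative values.
        self._image_est = np.maximum(self._image_est - self.step * self._grad(), 0)
```

Calling `update()` in a loop and watching `loss()` is a quick way to confirm the mask is applied on the right axes before moving the same `_grad` into a FISTA subclass.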
Hi, I used your old version of SpectralDiffuserCam. I want to use all of the provided reconstruction methods while adding a mask. Is it straightforward to somehow change your rfft_convolve.py with these from SpectralDiffuserCam?
def Hpower(self, x):
    x = np.fft.ifft2(self.H * np.fft.fft2(np.expand_dims(x, -1), axes=(0, 1)), axes=(0, 1))
    x = np.sum(self.mask * self.crop(np.real(x)), 2)
    x = self.pad(x)
    return x
Any other major changes that should be taken into consideration? Thanks!