HolyLab / BlockRegistrationScheduler.jl

Multi-core image registration scheduler

Mismatch increases in this case #35

Open Cody-G opened 7 years ago

Cody-G commented 7 years ago

For some time I've been trying to get good registrations when there is inconsistency in sampling locations along the z-axis. This is common with high-speed OCPI recordings. We're doing a lot to handle this at the level of the microscope hardware and acquisition software, but it would be good if our registration algorithm could clean up the last bit of it. I wasn't satisfied with the results, so I made a synthetic image volume and simulated the situation (script below). I'm seeing that the loss (aka mismatch) actually increases when registering these volumes. I haven't yet tracked down why, and this can probably be triggered by a simpler test case, but here's a demo that I know demonstrates the issue (note that it requires a utility package, BinaryVolumes, another HolyLab repo).

using BlockRegistration, BlockRegistrationScheduler, RegisterDeformation, RegisterMismatch
using AffineTransforms
using ImageView2, Images
using FixedSizeArrays
using JLD
using Interpolations

using BinaryVolumes

#utility function for running Apertures registration on a pair of images
function _reg_def(fixed, moving, knots, mxshift; lambda = 0.01)
    λ = lambda
    algorithm = Apertures[Apertures(fixed, knots, mxshift, λ, identity; pid=myid(), correctbias=false) for i = 1:1]
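    # ask the worker to report the deformation u: one Vec{3,Float64} per grid knot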
    mon = monitor(algorithm, (), Dict{Symbol,Any}(:u=>ArrayDecl(Array{Vec{3,Float64},3}, ([convert(Int,x.len) for x in knots]...))))
    mon[1][:mismatch] = 0.0
    fileout = tempname()
    @time driver(fileout, algorithm, moving, mon)
    u = JLD.load(fileout, "u")
    mm = JLD.load(fileout, "mismatch")
    return u, mm
end

#add jitter in z sampling location, simulating inconsistencies in piezo position when using OCPI under certain conditions
function z_jitter{T}(img::Array{T,3}, npix::Float64)
    @assert npix < 0.5 && npix >= 0.0
    etp = extrapolate(interpolate(img, BSpline(Linear()),OnGrid()), Flat())
    out = zeros(eltype(img), size(img))
    z_def = Float64[]
    for i in 1:size(img,3)
        r = (rand()*2*npix)-npix
        push!(z_def, r)
        for y = 1:size(img,2)
            for x = 1:size(img,1)
                out[x,y,i] = etp[x,y,i+r] #Interpolations doesn't currently support slicing
            end
        end
    end
    return out, z_def
end

#create template ellipsoid volume
radii = [10;20;15]*2
num_pad = 20
line_spacings = [6;6;6];

vmask = ellipsoid_volume(radii; fill_val=true, num_pad = num_pad);
vbound = find_boundaries(vmask; val=true);

#draw grid lines. Make density higher on one side so that there's no
#ambiguity in orientation.  Only need to do this for n-1 dimensions
sz = [size(vmask)...];
vgrid = deepcopy(vbound);
for i = 1:length(line_spacings)
    line_idxs_a = floor(Int,linspace(1,sz[i]/2, floor(Int,sz[i]/line_spacings[i])));
    line_idxs_b = floor(Int,linspace(sz[i]/2,sz[i], floor(Int, sz[i]/(1.5*line_spacings[i]))))[2:end];
    cond = x->x==true
    vgrid = draw_lines!(vgrid, vmask, i, line_idxs_a, cond; fill_val=true);
end
vgrid = map(Float64, vgrid);
fixed = vgrid

moving, z_def = z_jitter(vgrid, 0.49); #plus or minus a half frame of jitter

mxshift = (3,3,3)
gridsize = (2,2,2) 
knots = map(d->linspace(1,size(fixed,d),gridsize[d]), (1:ndims(fixed)...))

u, mm = _reg_def(fixed,moving,knots,mxshift;lambda=0.0);
d = griddeformations(u, knots)
warped = warp(moving, d[1])

#mismatch is worse after registration
imshow(colorview(RGB, moving, zeroarray, fixed))
imshow(colorview(RGB, warped, zeroarray, fixed))
@show ratio(mismatch0(moving,fixed), NaN) 
@show mm
Cody-G commented 7 years ago

Note that the mismatch does decrease when the grid size is made large in the z dimension (as it should be if one wants to correct this kind of jitter). I just discovered this by accident, and it concerns me that the algorithm is choosing a solution that increases the mismatch.

Even with full grid density in z it has trouble finding the right deformation, but it's not too far off. I'm looking into improving that now.

Cody-G commented 7 years ago

It's pretty complicated to understand what's happening in that test case, so I've been looking for simpler ways to trigger strange behavior. I've found some surprising cases with 1-D registration, again creating the moving image by jittering the sampling locations along the single dimension.

Here's an example script. If you run it a few times you'll notice that my guess is often better than the algorithm's, and the monitored mismatch is almost always negative. Could this be a problem with the way we're interpolating the mismatch?

using BlockRegistration, BlockRegistrationScheduler, RegisterDeformation, RegisterMismatch
using ImageView2, Images
using FixedSizeArrays
using JLD
using Interpolations

#utility function for running Apertures registration on a pair of images
function _reg_def(fixed, moving, knots, mxshift; lambda = 0.01)
    λ = lambda
    algorithm = Apertures[Apertures(fixed, knots, mxshift, λ, identity; pid=myid(), correctbias=false) for i = 1:1]
    mon = monitor(algorithm, (), Dict{Symbol,Any}(:u=>ArrayDecl(Array{Vec{1,Float64},1}, ([convert(Int,x.len) for x in knots]...))))
    mon[1][:mismatch] = 0.0
    fileout = tempname()
    @time driver(fileout, algorithm, moving, mon)
    u = JLD.load(fileout, "u")
    mm = JLD.load(fileout, "mismatch")
    return u, mm
end

#add jitter in sampling location, simulating inconsistencies in piezo position when using OCPI under certain conditions
function jitter{T}(img::Array{T,1}, npix::Float64)
    @assert npix < 0.5 && npix >= 0.0 #don't tear space
    etp = extrapolate(interpolate(img, BSpline(Linear()),OnGrid()), Flat())
    out = zeros(eltype(img), size(img))
    z_def = Float64[]
    for i in 1:length(img)
        r = (rand()*2*npix)-npix
        push!(z_def, r)
        out[i] = etp[i+r]
    end
    return out, z_def
end

fixed = zeros(6)
fixed[3:4] = 1.0

moving, z_def = jitter(fixed, 0.45);

mxshift = (1,)
gridsize = (6,)
knots = map(d->linspace(1,size(fixed,d),gridsize[d]), (1:ndims(fixed)...))
u, mm = _reg_def(fixed,moving,knots,mxshift;lambda=0.0);
@show mm #the monitored mismatch is negative. Red flag?
d = griddeformations(u, knots)
warped = warp(moving, d[1])

ideal_u = fill(0.0, size(u))
ideal_u[1,2,1] = -0.9
ideal_u[1,5,1] = 0.9

ideal_d = griddeformations(ideal_u, knots)
ideal_warped = warp(moving, ideal_d[1]);
@show ratio(mismatch0(moving,fixed), NaN)
@show ratio(mismatch0(warped,fixed), NaN) #this time it's not negative
@show ratio(mismatch0(ideal_warped,fixed), NaN) #our guessed answer is usually better
timholy commented 7 years ago

Thanks for the very helpful issue reports. I've started with the last report, the 1d example. I can trace the negative mismatch to a simple source: In your example, the mismatch at certain grid points looks like

z = NumDenom(0.0,1)
o = NumDenom(1.0,1)
mm = CenterIndexedArray([z, z, o])

and we can reproduce the negative penalty with a simple analog:

julia> using Interpolations

julia> penalty = [0.0, 0.0, 1.0]
3-element Array{Float64,1}:
 0.0
 0.0
 1.0

julia> itp = interpolate!(penalty, BSpline(Quadratic(InPlace())), OnCell())
3-element Interpolations.BSplineInterpolation{Float64,1,Array{Float64,1},Interpolations.BSpline{Interpolations.Quadratic{Interpolations.InPlace}},Interpolations.OnCell,0}:
  0.0
 -2.77556e-17
  1.0

julia> itp[1.51]
-0.08792000000000001

You get the same thing from

itp = CenterIndexedArray(interpolate!([z, z, o], BSpline(Quadratic(InPlace())), OnCell()))
RegisterPenalty.vecindex(itp, Vec(-0.49))

In a certain sense, the returned value is correct: if you interpolate a quadratic through those three points, the minimum will occur between the first and second point, and the value will be negative.

However, it's also not hard to understand how this could be problematic. Linear interpolation wouldn't suffer from the same problem, but the issue there is that (for efficiency) we really need the gradient of the penalty during optimization, and you don't have a smooth gradient unless you use at least quadratic interpolation.

Bottom line, I'm not quite sure what to do about this. I almost wonder if we should switch to a completely different optimization algorithm and focus only on integer-valued shifts.

timholy commented 7 years ago

One option might be to do linear interpolation and use a subgradient method. I might be tempted to move so that the maximum step is 0.1 pixel for any grid point, and just stop when the resulting penalty increases. I'll try cooking something up along those lines.
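
Something like that in sketch form (made-up function names, not package API; penalty and grad stand in for the interpolated mismatch-plus-affine penalty and its subgradient):

# Rough sketch of a step-capped subgradient descent.
function capped_descent(penalty, grad, u0; maxstep=0.1, maxiter=1000)
    u, p = copy(u0), penalty(u0)
    for iter = 1:maxiter
        g = grad(u)
        gmax = maximum(abs.(g))
        gmax == 0 && break
        utrial = u .- (maxstep/gmax).*g   # no grid point moves more than maxstep pixels
        ptrial = penalty(utrial)
        ptrial >= p && break              # stop as soon as the penalty increases
        u, p = utrial, ptrial
    end
    return u, p
end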

Cody-G commented 7 years ago

That would make sense to me! I also thought we would have to keep the step size small if we went with something like that.

There's one other thing I should bring up before we proceed. I think there could be a fundamental problem with interpolating the mismatch. Consider fixed = [0.0; 0.0; 1.0; 1.0; 0.0] and a moving image slid half a pixel to the left, moving = [0.0; 0.5; 1.0; 0.5; 0.0]. The ideal shift of the moving image would then be 0.5. If we evaluate the mismatch at integer offsets we get equal values at offsets of 0 and 1. Moreover, both of the values we get are kind of mediocre ~compared to the perfect 0 we could get if we evaluated the mismatch directly at 0.5 (assuming linear interpolation).~ Instead we get a range of equally-scored solutions from 0 to 1. Edit: well, actually, if we continue to use quadratic interpolation I guess they aren't equal, but I don't see a guarantee that the quadratic would help rather than hurt.

So my question is, how hard would it be to move the interpolation upward into the mismatch calculation itself? I think I don't understand Fourier methods well enough to know if that makes sense. Alternatively I think it may make sense to upsample the images before running registration, though that may get expensive...
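
To make the tie concrete, here's the sum-of-squares mismatch of that example at integer shifts (plain Julia, not the package's mismatch0):

fixed  = [0.0, 0.0, 1.0, 1.0, 0.0]
moving = [0.0, 0.5, 1.0, 0.5, 0.0]

# mismatch over the overlap when `moving` is shifted right by s pixels
ssd(f, m, s) = sum((f[i] - m[i-s])^2 for i in max(1,1+s):min(length(f),length(m)+s))

ssd(fixed, moving, 0)   # 0.5
ssd(fixed, moving, 1)   # 0.5 -- offsets 0 and 1 score identically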

timholy commented 7 years ago

Keep in mind that a linear interpolation of [0.0, 0.5, 1.0, 0.5, 0.0] at half-grid points is [0.25, 0.75, 0.75, 0.25]. There is no way to reconstruct the original image by interpolation.

Cody-G commented 7 years ago

Oh yes you are right! So we wouldn't get a perfect answer like I said. But we would still get a lower mismatch value than at the integer locations.

timholy commented 7 years ago

True. Interestingly, for that example, quadratic interpolation would suggest the right location for the minimum mismatch---the mismatch is symmetric around a half-grid point, so quadratic interpolation would place the minimum at 0.5.

Cody-G commented 7 years ago

But it seems to me like the quadratic may be hurting us as often as it helps. I ran registration just to verify that it finds the right minimum in that case, and surprisingly it doesn't:

using BlockRegistration, BlockRegistrationScheduler, RegisterDeformation, RegisterMismatch
using FixedSizeArrays
using JLD 
using Interpolations

#utility function for running Apertures registration on a pair of images
function _reg_def(fixed, moving, knots, mxshift; lambda = 0.01)
    λ = lambda
    algorithm = Apertures[Apertures(fixed, knots, mxshift, λ, identity; pid=myid(), correctbias=false) for i = 1:1]
    mon = monitor(algorithm, (), Dict{Symbol,Any}(:u=>ArrayDecl(Array{Vec{1,Float64},1}, (1,))))
    mon[1][:mismatch] = 0.0 
    fileout = tempname()
    @time driver(fileout, algorithm, moving, mon)
    u = JLD.load(fileout, "u")
    mm = JLD.load(fileout, "mismatch")
    return u, mm
end
fixed = [0.0;0.0;1.0;1.0;0.0]
moving = [0.0;0.5;1.0;0.5;0.0]
mxshift = (1,)
gridsize = (1,)
knots = ([3],)
u, mm = _reg_def(fixed,moving,knots,mxshift;lambda=0.0);
@show u   # == -0.336435
Cody-G commented 7 years ago

Sorry the last line should have been @show u, I updated the code to reflect that.

timholy commented 7 years ago

One issue I noted is that mxshift = (1,) will in fact confine the maximum deformation to 0.5 or smaller (it computes mismatch for shifts up to 1, but to evaluate the penalty you have to interpolate and for quadratic interpolation that requires that it be at least a half-step interior to the boundary). So your ideal_u is not quite a fair comparison.

But I have one or more fixes coming (still testing...).

timholy commented 7 years ago

Worth pointing out something else in this context. Setting lambda = 0 can be useful for debugging, and indeed this detected a bug in initial_deformation (fix coming). However, weird stuff can happen. For example, in the code I'm testing right now I got this:

julia> ϕ
RegisterDeformation.GridDeformation{Float64,1,Array{FixedSizeArrays.Vec{1,Float64},1},LinSpace{Float64}}(6-element Array{FixedSizeArrays.Vec{1,Float64},1}:
 Vec(0.0)
 Vec(-1.0)
 Vec(0.0)
 Vec(-1.0)
 Vec(1.0)
 Vec(0.0)
,(linspace(1.0,6.0,6),))

julia> warped
6-element Array{Float64,1}:
   0.0     
 NaN       
   0.898912
   0.898912
   0.0     
   0.0     

Notice the NaN in the middle of the image. That's because the second pixel in the warped image didn't "receive" any intensity from the original image. Using a small lambda fixes that problem.

Cody-G commented 7 years ago

Oops accidentally closed the issue via tab+enter

Using a small lambda fixes that problem.

Ah interesting, that is good to know. One thing I haven't quite figured out is how to decide what is a small lambda for a given grid size. It seems to me that the two parameters are linked in a complex way. i.e. a lambda = 0.01 may allow little deformation on a coarse grid but significant deformation on a finer grid. Do you have any tips about how to scale lambda with the grid size? Maybe there is a normalization I could do so that I get an affine penalty on a per-cell basis rather than for an entire image? Could it be as simple as dividing lambda by prod(gridsize)? (assuming all cells have the same volume).

timholy commented 7 years ago

Do you have any tips about how to scale lambda with the grid size?

I'd have to go through the code to remind myself what's being computed, but here's a guess. Our grid is a discrete approximation of a continuous penalty that looks something like

λ \int dx (D(x)-A(x))^2

where D is the deformation and A is the best-fit affine approximation of D. Converting this to a sum, we get

\sum (λ dx) (D(x)-A(x))^2

which implies that λ dx should be approximately constant. This seems to suggest that λ should scale like prod(gridsize), since dx scales like 1/prod(gridsize).

That said, auto_λ is a good way to test this.
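
In code, that heuristic would amount to something like (numbers purely illustrative):

λ_ref    = 0.01           # a value that worked at some reference grid size
grid_ref = (2, 2, 2)
grid_new = (2, 2, 101)
λ_new    = λ_ref * prod(grid_new) / prod(grid_ref)   # ≈ 0.5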

timholy commented 7 years ago

Given https://github.com/HolyLab/BlockRegistration/pull/46, I should post a script for testing things independently of BlockRegistrationScheduler:

function reg_lowlevel(fixed, moving, knots, mxshift, λ = 0.01, order=Quadratic)
    sdims = coords_spatial(fixed)
    gridsize = map(length, knots)
    aperture_centers = aperture_grid(size(fixed, sdims...), gridsize)
    aperture_width = default_aperture_width(fixed, gridsize, zeros(Int, ndims(fixed)))
    mms = mismatch_apertures(fixed, moving, aperture_centers, aperture_width, mxshift; normalization=:pixels)
    E0 = zeros(size(mms))
    cs = Array(Any, size(mms))
    Qs = Array(Any, size(mms))
    thresh = length(fixed)/prod(gridsize)/4
    for i = 1:length(mms)
        E0[i], cs[i], Qs[i] = qfit(mms[i], thresh; opt=false)
    end
    global mmc = deepcopy(mms)  # for debugging
    mmis = interpolate_mm!(mms, order)
    ap = AffinePenalty{Float64,ndims(fixed)}(knots, λ)
    ϕ, mismatch = RegisterOptimize.fixed_λ(cs, Qs, knots, ap, mmis)
end

and you can replace a couple of lines of your script (the one that calls _reg_def) with

ϕ, mm = reg_lowlevel(fixed, moving, knots, mxshift, 0.001, Linear)
warped = warp(moving, ϕ)
timholy commented 7 years ago

On your first script, if I allow 101 grid points along z, then I find that the best setting is something like

ϕ, mm = reg_lowlevel(fixed, moving, knots, mxshift, 10, Quadratic);

Using lower λ makes the fit worse. Unfortunately, it's still worse than the original:

julia> @show ratio(mismatch0(warped,fixed), NaN)
ratio(mismatch0(warped,fixed),NaN) = 0.0023416889785358806
0.0023416889785358806

julia> @show ratio(mismatch0(moving,fixed), NaN)
ratio(mismatch0(moving,fixed),NaN) = 0.0016513571370021908
0.0016513571370021908

Keep in mind that our registration does not try to align the images; all the registration is performed on the (block-computed) mismatch data. So a parameter setting that decreases the interpolated mismatch may not decrease the mismatch with the warped image.

This is obviously a weakness of our algorithm, but it's the key to its fast performance. If you have to warp and then compute the mismatch for every possible change in the deformation, that's going to be a big cost. We try to make things more efficient by predicting the impact that change will have by interpolation, but obviously it doesn't always work very well.

I wonder if we should create a completely different algorithm for handling just the z component? It's much faster to interpolate along just one axis, which you may know you can do like this:

julia> A = rand(5, 5, 7);

julia> itp = interpolate(A, (NoInterp(), NoInterp(), BSpline(Quadratic(Flat()))), OnCell());

This would involve 3 rather than 27 evaluation points per output point.

timholy commented 7 years ago

It occurs to me that there may be a solution that allows us to have our cake and eat it too. The mismatch calculation looks like sum((fixed[x] - moving[ϕ[x]])^2), and we've figured out good ways to do that in "blocks" to exploit the FFT (for efficiency). We then estimate the effect of sub-pixel shifts by interpolating the mismatch value for integer shifts---effectively, this estimates the mismatch gradient by finite differencing. As an alternative approach, conceivably we could also calculate the gradient of this quantity with respect to ϕ in "block" form (i.e., exploiting the FFT). Then we'd have to do the interpolation so that we match both the value and gradient at each point.

It seems possible that it would afford higher accuracy, because we'd have the actual gradient of the mismatch (at integer-pixel positions) rather than an estimate of the gradient calculated by finite differencing.

However, before plunging in, considerable caution seems warranted. Implementing this would be a lot of work; not only would we have to implement the calculation of the gradient for integer shifts, but we'd have to develop the interpolation scheme from the math on up (I'm not aware of any such interpolation scheme in any language). I'd want a clear demonstration (in a simple situation, e.g., 1d) that the improvement in estimating the mismatch gradient was worth all the effort.
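
In 1d, the value-and-slope-matching piece at least has a textbook form, cubic Hermite interpolation; the hard part would be generalizing it to our NumDenom data in multiple dimensions:

# Cubic Hermite interpolant on [0,1]: reproduces the values p0, p1 and the
# slopes m0, m1 at the endpoints, so an exactly-computed mismatch gradient
# could be matched at each integer shift.
function hermite(p0, p1, m0, m1, t)
    h00 =  2t^3 - 3t^2 + 1
    h10 =   t^3 - 2t^2 + t
    h01 = -2t^3 + 3t^2
    h11 =   t^3 -  t^2
    return h00*p0 + h10*m0 + h01*p1 + h11*m1
end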

An alternative is to add a stage at the end that iteratively warps and tests the result to achieve subpixel accuracy.

timholy commented 7 years ago

It's worth pointing out that the scheme above still wouldn't accurately model the change in the deformation between grid points: any time you use the FFT, you're forced to hold the shift constant over each block. In contrast, when we go to warp the image, we interpolate the deformation. Consequently the above scheme still wouldn't represent a fully-accurate model of how the mismatch will depend on the deformation.

timholy commented 7 years ago

See https://github.com/HolyLab/BlockRegistration/pull/48

Cody-G commented 7 years ago

Sorry for the delay in answering this!

This seems to suggest that λ should scale like prod(gridsize), since dx scales like 1/prod(gridsize). That said, auto_λ is a good way to test this.

Thanks, I will do that!

I wonder if we should create a completely different algorithm for handling just the z component? It's much faster to interpolate along just one axis, which you may know you can do like this:

I was thinking about doing something like this as a preprocessing step. I had imagined just finding the linear combination of adjacent moving slices that matches each fixed slice the best. Do you think we will gain something from quadratic interpolation in that case? I can try it both ways.

As an alternative approach, conceivably we could also calculate the gradient of this quantity with respect to ϕ in "block" form (i.e., exploiting the FFT). Then we'd have to do the interpolation so that we match both the value and gradient at each point

I think this may be akin to what I (vaguely) suggested, to "move the interpolation upward into the mismatch calculation". I do recognize it's a lot of work, and I'm not sure I would trust myself to handle the FFT gradients properly without spending a long time. I expect that we could get better sub-pixel registration if we could pull it off. I can start by testing the iterative warping approach. If that improves accuracy a lot then maybe we can consider implementing this to improve performance (though I realize that this could still have accuracy shortcomings compared to iterative warping for the reason you mentioned).

I'll be testing this stuff out today and tomorrow, thanks for the improvements!

Cody-G commented 7 years ago

In addition to the new work underway in BlockRegistration I have found an approach that I like for (mostly) correcting the z-jitter problem very quickly. I can imagine doing this as a preprocessing step before registering each image with our more general algorithms.

It's easiest to explain if you take the first stack as the fixed stack. Stack 2 is then the first "moving" stack. We perform a correction on stack 2 by examining pairs of adjacent moving slices in combination with the fixed slice. The idea is that each fixed slice was sampled "between" a pair of moving slices in physical space, so we can find the interpolated slice in the moving image that best matches each fixed slice. The set of best-matching interpolated slices then replaces the original moving image. We repeat the same procedure with stack 3, but using the corrected stack 2 as the fixed image. This is dirt simple really, and the optimal (linear) interpolation coefficient for each stack can be found analytically. (Before I realized the analytical solution was so easy I used an iterative approach with Optim, and it actually failed to find the optimal solution. I never tracked down why.)

There is one caveat: if a fixed slice doesn't fall "between" a pair of moving slices then it fails. This can actually happen when the fixed slice happens to be sampled in a plane with a lot of peaks in intensity relative to surrounding space. In order to address this I also solve the inverse problem: solve again for the coefficients with the fixed and moving roles swapped. We then end up with complementary estimates of the displacement, and we can weight each one by the mismatch that it achieves. This can be viewed as a form of regularization, and in extreme cases the solution we get by combining the forward and inverse solutions can actually increase the mismatch even though the displacements it finds are correct. There are more details in the comments of my gist.
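
For each fixed slice F sampled between moving slices A and B, the coefficient comes from a one-parameter least-squares fit, something like this sketch (the gist has the real code):

# Minimize ||(1-α)*A + α*B - F||² over α; setting the derivative to zero
# gives α = ⟨B-A, F-A⟩ / ||B-A||².  α in [0,1] means F falls "between" A and B.
function blend_coeff(A, B, F)
    d = vec(B) .- vec(A)
    return sum(d .* (vec(F) .- vec(A))) / sum(abs2, d)
end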

I can clean up the gist, write tests, and push it to Github if you don't see any fundamental problems with the approach. In practice it's improving my images a lot by eye, and it's nice that it's so fast (about 1 second per stack to find the coefficients). Actually currently it takes 3x longer to resample the moving image than it does to find the solution. I'm not sure why that is; the resampling function looks fine via @code_warntype. The main difference is that I use Interpolations for resampling but don't use it to find my solution. I can actually do everything without Interpolations but I'd like to understand whether I'm misusing it somehow. If you have any feedback on how to improve performance that would be great!

Here's the gist:

https://gist.github.com/Cody-G/c36831daf2db7c469713db9507a05aea

timholy commented 7 years ago

This is awesome. I love it because it's so directly targeted at the specific problem that is causing you the most trouble. I definitely see this as an excellent "preprocessing" step. If it obviates the need for https://github.com/HolyLab/BlockRegistration/pull/48 so much the better, although I won't be surprised if in the end we want both. But I'd suggest going piecemeal, and this sounds like the best piece to add first.

Before you turn it into a PR, here's one thing I noticed. Let M be the moving stack reshaped to be two-dimensional, collapsing the first two dimensions but keeping the z axis separate: M = reshape(moving, size(moving, 1)*size(moving, 2), size(moving, 3)). Let F be the same thing for the fixed image. Then the condition that a linear combination of M slices can make F is simply M*W = F. Consequently W = M\F. The prediction would be that, for each slice, W is significantly nonzero only for two entries. If it proved necessary, you could even use our NNMF HALS code from the cell segmentation to constrain W to be nonnegative.

This might be considerably slower than your approach, because you only consider pairs, but it might be at least worth playing with. Those BLAS routines are fast.

Cody-G commented 7 years ago

Formulating it as a single matrix equation is attractive. In that case adding a non-negativity constraint may be important, as I've seen negative solutions pop out with my pairwise approach. I can easily throw those solutions away now, so it would be nice to keep that ability. If we go with NMF are we guaranteed to find the best solution? I didn't think that was guaranteed with our HALS approach, but I'm not sure.

Probably the more important concern is how to constrain entries of the weight matrix to be less than one. If an entry is greater than one it can no longer be interpreted as an interpolation (more like an extrapolation). Again I can deal with it now by just throwing away infeasible solutions. The only way I know how to solve the full matrix method is with a Lagrange multiplier for each slice, though I think there are a lot of methods I'm not familiar with. One could also argue that it's okay to let the coefficients exceed one to respect the reality that some slices could have a longer dwell time over the same region than others, but I don't think that difference is very significant. Even if it were I think we would also capture it in the pairwise "inverse" solution.

I'll play around with this. The data may already constrain the problem enough that my concerns are for nothing...

timholy commented 7 years ago

If we go with NMF are we guaranteed to find the best solution? I didn't think that was guaranteed with our HALS approach, but I'm not sure.

General NMF is not convex so there is no guarantee of getting the best answer, no matter what algorithm you use. However in this case it's a simpler problem, since one of the two factors is already specified, so this is really a case of nonnegative least squares, which is convex (https://en.wikipedia.org/wiki/Non-negative_least_squares). So you're guaranteed to find the optimum. I suspect that's true even if you add an upper-bound constraint, but I'm not sure. That's certainly true of the single-factor, single-coordinate updates that HALS is based on. Perhaps the most accessible treatment I know of is http://cmp.felk.cvut.cz/ftp/articles/franc/Franc-TR-2005-06.pdf, with the key point being covered in section 4.
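
For reference, the single-coordinate update it's built on is simple enough to sketch (this solves min ||M*w - f||² subject to w .>= 0 for one column f of F; it's not our HALS code):

# Cyclic coordinate descent for nonnegative least squares: each update is the
# unconstrained 1-d optimum for w[j], projected back to w[j] >= 0.
function nnls_cd(M, f; iters=100)
    n = size(M, 2)
    w = zeros(n)
    r = copy(f)                                  # residual f - M*w (w starts at 0)
    nrm = [sum(abs2, view(M, :, j)) for j = 1:n]
    for it = 1:iters, j = 1:n
        nrm[j] == 0 && continue
        wj = max(0.0, w[j] + sum(view(M, :, j) .* r) / nrm[j])
        if wj != w[j]
            r .-= (wj - w[j]) .* view(M, :, j)   # keep the residual in sync
            w[j] = wj
        end
    end
    return w
end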

However, I also agree that the unconstrained problem might be good enough, if we're careful in how we use the solution.

Cody-G commented 7 years ago

Here's an interesting view of the full matrix solution. I zoomed in on the diagonal.

[image: blas_result]

Notice that the dark spots are negative. That was a synthetic dataset. Here is the result from a pair of real stacks:

[image: blas_result_real]

I think it's actually pretty neat, and it smells a bit like deconvolution! I'm tempted to use all of the information around the diagonal, but I haven't fully convinced myself that this is fair in all circumstances. What do you think?

In terms of performance the full solution takes about 5x as long as my pairwise solution. If I use the linear algebra solver to enact the pairwise solution (i.e. moving[:,3:4] \ fixed[:,3]) then the time is about the same as my method -- it's also cool that I was able to match BLAS performance with regular Julia code.

I'm thinking the only reason to change to BLAS is if we want to use the full, unconstrained matrix. Like I said, it's really tempting. I will see if I find any obvious problems in the resulting images.

Cody-G commented 7 years ago

Hmm, I have a linear algebra question. I solved for W with moving\fixed. I assumed that if moving * W == fixed then moving == fixed * inv(W), so I can also reconstruct the moving image from the fixed image by multiplying by the inverse of W. However inv(W) is not giving me a good inverse: if I multiply W * inv(W) I get something very far from the identity matrix:

[image: bad_inverse]

So of course when I try to use the inverse matrix it does a terrible job of reconstructing the ~fixed~ moving image. Maybe my W happens to be non-invertible but Julia is trying anyway? If so it would be nice if Julia gave a warning or something...

Cody-G commented 7 years ago

Oh wow, if I convert W to Float64 before taking the inverse I get something much better:

[image: good_inverse]

Maybe the roundoff error gets out-of-hand for Float32? Should this be considered a Julia bug?

Cody-G commented 7 years ago

...and if I go back and solve for W on real data again using 64-bit precision then the solution looks quite different (compare this to the most recent matrix image I posted)

[image: real_result_f64]

It's a bit scary that the results are so different! I'll be more careful about using Float32s from now on.

Unfortunately even with 64-bit precision the inverse matrix doesn't do a good job of reconstructing ~fixed~ moving. Do you understand why?

timholy commented 7 years ago

It's a bit scary that the results are so different! I'll be more careful about using Float32s from now on.

It's possible there's something ill-conditioned. Worth considering MF = qrfact(M) and looking at MF[:R]. Any diagonal elements close to zero? If you have black images at the top and bottom of your stack, that could spell trouble---those will give you nearly-zero R diagonals (because there's no useful information), and that could mess up everything else. But the good news is that you might be able to do a manual inversion (i.e., do M\F more carefully), something like this (warning, untested):

Q, R, p = qr(M, Val{true})  # pivoting should put your least-informative frames last
# For any diagonal elements of R that are almost zero, or even much smaller than the rest,
# set them to Inf here
Wp = R \ (Q' * F[:, p])
pinv = invperm(p)
W = Wp[pinv, pinv]  # not sure I've got this right...

Unfortunately even with 64-bit precision the inverse matrix doesn't do a good job of reconstructing moving

What does F = svdfact(W); F[:S] give you?

timholy commented 7 years ago

It might be a little slower, but perhaps it would be easiest to do all this with the svd:

USV = svdfact(M)
S = USV[:S]
thresh = sqrt(eps(eltype(S))) * maximum(S)
S[S .< thresh] = Inf
W = USV \ F

That way you don't have to deal with pivoting.

Cody-G commented 7 years ago

I tried all of your suggestions for calculating W, and I'm still getting a poor reconstruction of the moving image. The fixed image reconstruction looks great (i.e. M*W), but F*inv(W) looks bad. Maybe I'm making some very basic mistake, so I'll just share with you the code I'm using:

m64 = sqrt.(Float64.(moving)) #convert from UInt16, square root transform
f64 = sqrt.(Float64.(fixed))

szxy = prod(size(m64)[1:2])
szz = size(m64,3)
M = reshape(m64, szxy, szz)
F = reshape(f64, szxy, szz)
W = M\F 
F_reconstr = M*W #looks good
W_reconstr = F*inv(W) #very noisy image, can barely see anything

Worth considering MF = qrfact(M) and looking at MF[:R]. Any diagonal elements close to zero?

Not very close. Here's a plot of the (sorted) entries in the diagonal. There is one outlier, could that be a problem?

[image: r_diagonal1]

But the good news is that you might be able to do a manual inversion (i.e., do M\F more carefully), something like this (warning, untested):

This does work for getting me a W that reconstructs fixed, so I think you got it right. But again using the inverse doesn't work.

What does F = svdfact(W); F[:S] give you?

Another scatter of the entries:

[image: f_s]

perhaps it would be easiest to do all this with the svd

Same story as other methods for calculating W

timholy commented 7 years ago

Can't really tell without log scaling on the y-axis, but some of those last singular values look extremely close to 0 (the last one in particular). If so, it's not at all surprising you're getting bad reconstruction; if some of the singular values are close to 0, you can never take inv(W) at face value.

Just use the same trick again:

Wfact = svdfact(W)
S = Wfact[:S]
thresh = sqrt(eps(eltype(S))) * maximum(S)
S[S .< thresh] = Inf
Mrecon = Wfact \ F

If you use this trick wherever you need it, I bet you'll be fine with Float32 again, too.

timholy commented 7 years ago

I can't try your specific example because I don't have fixed and moving.

Cody-G commented 7 years ago

Your last line, Mrecon = Wfact \ F isn't working for me because dimensions don't match. I can transpose F to get the dimensions right. In that case I think I should also transpose Wfact? Anyway I haven't yet found an expression that works. I think I could also accomplish what I want by just solving for two different weight matrices: M \ F and F \ M. I just wanted to double check that I understood the math, but I guess the nearly-zero singular values are complicating things. I've copied this fixed and moving image to /usr/lab/share/cody_to_tim as fixed.nrrd and moving.nrrd if you get time to take a look.

I also started playing with the idea of removing jitter in all three dimensions simultaneously using this approach (we also get sub-pixel high frequency X-Y jitter from vibrations and the fish's heartbeat). I haven't yet figured out how to set up the linear algebra to solve for three dimensional recombination of slices. I started reading up a bit about tensors...does this sound like a useful direction to you? I'm thinking these sub-pixel corrections could make a big difference with segmentation.

Cody-G commented 7 years ago

Just to expand on the last thing I said about removing jitter in multiple dimensions: I think for the 2D case we would need to solve for two weight matrices W and X in an equation like this.

(MX)'W == F

That's one equation with two unknowns, which seems underconstrained. Intuitively it makes sense that it's underconstrained because you can imagine a whole family of solutions that differ only by a scaling factor of X and W (i.e. if you think of the two recombinations as occurring sequentially, each recombination is a translation and an intensity scaling; the intensity scaling is free to vary as long as the combined scaling is correct). If that's the only family of solutions, then I think the solution space is just a line and we would be happy with any solution on the line. I'm not sure how to modify the equation to exploit that. I may be way off base here, but it seems like there should be a way to formulate this problem for 2D and 3D...

I also noticed that any of the recombinations we are interested in can be reduced to one-dimensional under a coordinate transformation. But I'm not sure it's even possible to mix coordinate transformations with our current recombination approach--I haven't yet found a way.

I could always just create a loss function and pass off the problem to Optim, but I would think there is a more elegant and reliable way to do this directly with linear algebra.

timholy commented 7 years ago

That's a really interesting suggestion. I've spent some time playing with this in this gist. Because I found that the unconstrained problem settled on really whacky values for the "interpolation coefficients," I implemented it so that the sum-of-absolute-values of the "interpolation coefficients" is 1. (This is an L1 regularization, quite reminiscent of LASSO except that there isn't an unknown Lagrange multiplier.)

On a couple of simple tests, I've gotten very good results for shifts along a single axis, but as soon as I shift along two axes simultaneously I get junk. I don't see a way out, currently. Pity, because it's a really creative idea!
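
For the record, stripped of the L1 constraint the alternation is essentially this (bare sketch; eye is the Julia-0.6 spelling of the identity):

# Unconstrained alternating least squares for F ≈ X'*M*W: each half-step is a
# plain linear solve holding the other factor fixed.  Without the constraint,
# this is what settles on the whacky coefficients.
function als(M, F; iters=20)
    X = eye(size(M, 1))
    W = eye(size(M, 2))
    for it = 1:iters
        W = (X'*M) \ F      # best W given X
        X = (M*W)' \ F'     # best X given W, from (X'*(M*W))' == F'
    end
    return X, W
end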

Cody-G commented 7 years ago

Interesting. What if we give it a good initial guess? I think we could initialize the weight matrices with a simple transform that aligns the centroids of the two images. I think we could do that by shifting the two starting identity matrices by their respective coordinates of the centroid difference vector. I will give that a try after I figure out what's causing the gradient tests to fail in RegisterPixelwise.

Cody-G commented 7 years ago

When I use a good initial guess with your example in the gist, the quality of the reconstruction actually goes down after the first iteration and then oscillates all over the place. I didn't study very carefully how you're doing the coordinate descent, but maybe we could try a different method?

timholy commented 7 years ago

Hmm, that's interesting. Theoretically it should result in guaranteed descent; if it's not doing that, it suggests a bug.

As for switching to a different method, I'm all ears. But I don't see a simple alternative, do you? Coordinatewise descent is, as far as I know, a state-of-the-art method for LASSO (and simple, too).

Cody-G commented 7 years ago

Okay I have good news! The less constrained method in your updated gist works! Also I've expanded it to 3D, and it also seems to work well there. I'll clean it up a bit and post a new gist. Some observations

timholy commented 7 years ago

That sounds fantastic! I'm really excited that this seems promising!

With regards to the limitations, I think it's more general (and therefore harder to solve) than issues specifically with coordinate descent. The algorithm that constrains only sum(x) technically isn't using coordinate descent, although ALS (alternating least squares) is essentially "generalized coordinate descent": in 2d, it tackles "half" of the parameters simultaneously. (It would literally be half if the image is square.) With regards to the specific limitations listed:

I suspect the biggest issue is that there are surely many minima and we're not guaranteed to find the best one. In 1d the problem is convex, but it's not in any higher dimensionality. Thus we are guaranteed to find the global minimum for a 1d problem, but there is no such guarantee for 2d or higher.

A couple of your points do make me throw out one caution. Appearance of the image is not always the best guide to success; in particular, this approach does seem to run some risk of falling into the CURT trap. At some point we'll have to take the output of this algorithm and turn it into something that's in the spirit of a deformation. One of the things I like about this is that it can come up with things that aren't strictly interpolations (e.g., it can exploit negative coefficients), but in the end I suspect we'll need to impose some kind of locality.

It would be really interesting to see what this algorithm does with a small-angle rotation. In particular it would be exciting if we can somehow make sense of the parameters to extract the angle(s). In contrast with the shift case, I can't easily picture what those matrices should look like. I even wonder if we'd need to consider multi-tensor solutions, i.e., instead of our current version which can be written as

[equation image]

I wonder if we might need to use

[equation image]

Obviously we don't want to go there until we know we need to.

Cody-G commented 7 years ago

the first one isn't applicable because the objective function is infinitely differentiable (smooth)

I may have misidentified that as the issue, but our problem when doing alternating least squares feels analogous when you consider a test case like this:

fixed = [1.0 0;
         0 0]
moving = [0 0;
          0 1.0]

If we think of the X and Y matrices as swapping rows and columns (I know they can do more than that, but in these simple test cases that's all we want them to do) then there's no single swap that can decrease the objective, only a combination of two swaps can do it. That feels like we are in some hard corner of a level set of the objective, though I may be visualizing it incorrectly. If we were optimizing both matrices simultaneously we would not be at a local minimum. Rather, we are stuck at a matrix-wise local minimum.

I also agree that it's very difficult to extract something that's in the spirit of a deformation, especially if we want to include rotations or, most generally, affine transforms.

Here's one proposal that I think may be able to solve the original problem I wanted to solve, which was correcting for small differences between adjacent stacks in a timeseries (probably never more than 1 pixel of motion):

  1. If necessary, initialize our weight matrices by calculating the mismatch over a small maxshift using RegisterMismatch. This may not be necessary, and it could even make things worse if the shift is truly sub-pixel, so I would say we should look for the best integer shift (without interpolating the mismatch)
  2. Restrict the vast majority of the weight matrix entries to be zero. We just need to optimize the entries along a diagonal stripe centered on the initialization (see the sketch after this list). The stripe should be at least 3 entries thick to allow interpolation using a center slice and its two adjacent slices. We can play with using a thicker stripe later if we're happy with the solutions that we're finding.
  3. Optimize all unconstrained matrix entries simultaneously. I'm not sure what method is best to use for this, maybe just something like BFGS?
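
A sketch of the stripe in step 2 (names and the default width are just illustrative):

# Mask of the weight-matrix entries left free to optimize: a diagonal stripe
# centered on the per-slice integer shifts found in step 1.
function stripe_mask(nz::Int, shifts::Vector{Int}; halfwidth::Int=1)
    free = falses(nz, nz)
    for j = 1:nz                          # column j reconstructs fixed slice j
        c = clamp(j + shifts[j], 1, nz)   # stripe center for this slice
        for i = max(1, c-halfwidth):min(nz, c+halfwidth)
            free[i, j] = true
        end
    end
    return free
end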

Here's an updated version of your gist that includes my 3D implementation and some test cases demonstrating the points I've made.

https://gist.github.com/Cody-G/9d3c0e5c8a7fb20812234c1f8ad7457e

I'm doing the tensor operations very inefficiently, calling permutedims a lot. I thought about doing those operations with PermutedDimsArrays but I'm guessing this would break any optimizations that BLAS uses for cache efficiency? There's also a package https://github.com/Jutho/TensorOperations.jl that may offer better efficiency. I thought I would wait a bit to optimize for performance until we've decided on an algorithm.

timholy commented 7 years ago

That feels like we are in some hard corner of a level set of the objective, though I may be visualizing it incorrectly. If we were optimizing both matrices simultaneously we would not be at a local minimum. Rather, we are stuck at a matrix-wise local minimum.

I see your point. Certainly similar in spirit, if not exactly the same. It's essentially like a conventional deformation-based approach where you parametrize the deformation as a simple translation, and suppose you're aligning two images which are black (0) everywhere except for one bright blob. With conventional approaches, you move in the direction determined by the gradient, but if there's no overlap between the blobs, then the gradient at your starting point (no translation) is 0. So you might be able to reduce the mismatch, but the gradient doesn't help you figure out how to do it. This is part of the reason that our "forecasting" method aims to find the global minimum within the set of all allowable shifts.

timholy commented 7 years ago

I'm doing the tensor operations very inefficiently, calling permutedims a lot

It's quite possible this is the most efficient approach. I definitely wouldn't worry about it until you know it's a problem.

Cody-G commented 7 years ago

but if there's no overlap between the blobs, then the gradient at your starting point (no translation) is 0

I see the analogy, but it's worth emphasizing that the alternating method can fail even when there is image overlap, as in the case below. If we were optimizing both matrices simultaneously I think the gradient would actually help us find the right solution in this case:

fixed = [1.0 0 0;
         0.0 1 0;
         0.0 0 0]
moving = [0 0 0;
          0 1 0
          0 0 1.0]          

There's actually a lot of literature on solving similar problems. I think that PARAFAC might be the most similar to what we want to do, but I'm not sure. ALS seems like a popular algorithm but it's not the only one. I'm not sure how much time it's worth putting into this...I'll be back in the office soon, maybe we can discuss more in person.

http://epubs.siam.org/doi/pdf/10.1137/07070111X http://www.models.life.ku.dk/~rasmus/presentations/parafac_tutorial/paraf.htm https://en.wikipedia.org/wiki/Tensor_rank_decomposition

timholy commented 7 years ago

I did think of one potential concern with this approach, even for the case of 1d registration, although things we're already incorporating may solve the problem effectively. Let's suppose you're looking at a field of view in which there is no movement, but in one particular frame every neuron gets brighter simultaneously. Then the matrix might be something like 0.8*eye(n,n), i.e., it will make the intensities dimmer so as to match the fixed image.

This makes it seem all the more important to constrain the sum over columns to be 1 (which circumvents that problem).
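
A toy version of that failure mode, with M and F the reshaped slice matrices from before:

M = 1.25 .* F       # same scene, every slice uniformly 25% brighter
W = M \ F           # comes out ≈ 0.8*I: pure dimming, no shift information
W = W ./ sum(W, 1)  # forcing column sums to 1 restores W ≈ I  (sum(W, dims=1) on Julia ≥ 0.7)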

Cody-G commented 7 years ago

I agree with your concern about global brightness changes. I just created a repository with my work on this so far. You can switch between three methods using a keyword argument: one is unconstrained, another is the sum-to-1 constraint, and another is the non-negative constraint. As we discussed, I'm thinking the nonnegative one may be helpful because with negativity it manages to find some creative (and physically impossible) ways to combine slices to cancel out shot noise. I do wonder whether we should just allow it to do this noise canceling. Images tend to get blurred a bit, but I don't think it's reducing the SNR (still need to quantify).

In its current form the nonnegative method isn't working. I tried to use our nmf_B1! method from CellSegmentation, but I haven't been able to get it to work. I did try another NMF method from a public Julia package, and it does work but it's very slow. You can see both methods in the source code. I'll open an issue on the new repository with more info (https://github.com/HolyLab/RegisterInterp)

Cody-G commented 7 years ago

Here's the issue over there: https://github.com/HolyLab/RegisterInterp/issues/1