JuliaImages / Images.jl

An image library for Julia
http://juliaimages.org/

An efficient point/patch-wise "distance" #918

Closed johnnychen94 closed 2 years ago

johnnychen94 commented 4 years ago

This need arose while I was implementing a faster WNNM[1] image denoiser (still WIP; a 7x speedup at the time of writing). I'm not sure how broadly this can be used, so I'd like to open an issue here to get some early feedback.

Background

After nonlocal means filters[2] and BM3D[3], it has become a consensus that grouping similar patches via block matching and denoising at the patch level performs better than denoising at the pixel level. A typical block-matching denoising workflow is as follows:

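```julia
# For simplicity, border conditions are ignored in this and the following code snippets.

# `r` is the patch half-size, e.g. CartesianIndex(3, 3) for 7×7 patches
function block_matching_denoiser(f, img; r=CartesianIndex(3, 3))
    out = zeros(float(eltype(img)), axes(img))  # accumulated patch estimates
    W = zeros(Int, axes(img))                    # number of estimates per pixel
    for p in CartesianIndices(img)
        patch_p = @view img[p-r:p+r]
        # index ranges (q-r:q+r) of the most similar patches
        matched_patches = block_matching(img, patch_p; num_patches=80)

        # input: m*n*N
        # output: m*n
        patch_out = f(img, patch_p, matched_patches)

        # aggregate the estimate into every matched patch location
        for R in matched_patches
            view(out, R) .+= patch_out
            view(W, R) .+= 1
        end
    end

    # weighted summation
    # typically called `patch2img` in the MATLAB world
    out ./= W
    return out
end
```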

The algorithms differ mainly in f: nonlocal means uses a weighted mean, BM3D uses a sophisticated 1D+2D filter, and WNNM uses a low-rank approximation (via SVD).
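As an illustration of the simplest choice of f, a nonlocal-means style weighted mean could look roughly like the sketch below. It assumes the matched patches are already stacked into an m×n×N array and weights each one by exp(-SSD/h^2) relative to the reference patch; the name nlm_aggregate and the smoothing parameter h are hypothetical, not part of any package.

```julia
# Hypothetical nonlocal-means style `f`: a weighted mean of the matched patches.
# `patch_p` is the m×n reference patch, `stack` is an m×n×N array of matched
# patches, and `h` is an assumed smoothing parameter.
function nlm_aggregate(patch_p::AbstractMatrix, stack::AbstractArray{<:Real,3}; h=0.1)
    N = size(stack, 3)
    # similarity weight of each matched patch: exp(-SSD / h^2)
    weights = [exp(-sum(abs2, view(stack, :, :, k) .- patch_p) / h^2) for k in 1:N]
    # weighted mean over the third (patch) dimension -> m×n
    out = zero(float.(patch_p))
    for k in 1:N
        out .+= weights[k] .* view(stack, :, :, k)
    end
    return out ./ sum(weights)
end
```

Patches far from the reference receive weights close to zero, so they barely contribute to the output.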

This algorithm has two computational bottlenecks: block_matching and f. In this issue, I'll focus only on block_matching. (I have ideas on how the SVD can be optimized for this very specific task, but that's out of scope here.)

Here's a demo of block matching (from the BM3D website).

A block-matching subroutine, in its naive form, is:

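```julia
# `f(patch_p, patch_q)` measures how similar two patches are.
# `local_neighbor(p, patch_stride, search_size)` (not shown) is assumed to return
# a vector of candidate centers q around p.
function block_matching(f, img, p; num_patches, patch_size, patch_stride, search_size)
    r = CartesianIndex(patch_size .÷ 2)   # patch half-size, e.g. (7, 7) -> (3, 3)

    patch_p = @view img[p-r:p+r]
    _measure(q) = f(patch_p, view(img, q-r:q+r))
    qs = local_neighbor(p, patch_stride, search_size)
    dist = _measure.(qs) # the computational bottleneck

    # only the `num_patches` closest candidates are needed, so a partial sort suffices
    matched_points = qs[partialsortperm(dist, 1:num_patches)]
    return [q-r:q+r for q in matched_points]
end
```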
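To make the cost concrete, here is a self-contained, runnable version of the naive routine for 2D images. It assumes SSD as the patch distance and replaces the unspecified local_neighbor with a square search window clipped so that every candidate patch stays inside the image; naive_block_matching and its keyword names are illustrative only.

```julia
# Naive block matching for a single reference pixel `p` (illustrative sketch).
# Candidates are all q within `search_radius` of p whose patch fits inside the
# image; the patch distance is the sum of squared differences (SSD).
function naive_block_matching(img::AbstractMatrix, p::CartesianIndex{2};
                              patch_radius::Int=1, search_radius::Int=3,
                              num_patches::Int=4)
    r = CartesianIndex(patch_radius, patch_radius)
    s = CartesianIndex(search_radius, search_radius)
    R = CartesianIndices(img)
    inner = first(R)+r:last(R)-r    # centers whose whole patch lies inside the image
    p in inner || throw(ArgumentError("the patch around p must lie inside the image"))
    patch_p = @view img[p-r:p+r]

    # candidate centers: the search window around p, clipped to `inner`
    lo = CartesianIndex(max.(Tuple(p - s), Tuple(first(inner))))
    hi = CartesianIndex(min.(Tuple(p + s), Tuple(last(inner))))
    qs = vec(collect(lo:hi))

    # patch distance: sum of squared differences (SSD)
    ssd(q) = sum(abs2, patch_p .- view(img, q-r:q+r))
    dist = map(ssd, qs)                      # the computational bottleneck

    k = min(num_patches, length(qs))
    best = partialsortperm(dist, 1:k)        # indices of the k smallest distances, sorted
    return qs[best], dist[best]
end
```

Calling this for every p in the image recomputes the same pixel differences many times over, since neighboring reference patches share most of their search windows.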

Block matching for one pixel seems fine, but it is only a subroutine inside the outer loop, so a lot of computation is repeated: the distance between the patches at p and q is computed again when q becomes the reference pixel, and overlapping patches recompute the same pixel differences. For exactly this reason, search_size is usually set to a relatively small number (to reduce computation and memory requirements). This issue is about strategies to remove this unnecessary computation in a memory-friendly way.

Existing implementations are tweaked in ways that can hardly be reused by other code: they are deeply coupled with the outer for loop, which hurts both performance (e.g., multi-threading them becomes a big challenge) and code reuse. This issue also tries to provide an easy-to-use interface that keeps all optimizations transparent to users.

Proposal

Naively, representing the complete pointwise distance result of arrays A and B requires an array with axes (axes(A)..., axes(B)...). This array is extremely large even for two moderately sized matrices: comparing a 512×512 image with itself gives 512^4 ≈ 6.9×10^10 entries, or 512 GiB as Float64, so we cannot create it directly.
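One way to avoid materializing that array is to compute each entry only when it is indexed. Below is a minimal sketch, assuming 1-based arrays and an element type inferred from one sample evaluation; LazyPointwise is a hypothetical name, not the type proposed here or an existing package API. It has no caching, so repeated lookups recompute f.

```julia
# A minimal lazy "pointwise distance" array (hypothetical, not a package API).
# D[i..., j...] == f(A[i...], B[j...]) is computed on demand, so nothing of size
# (size(A)..., size(B)...) is ever allocated.
struct LazyPointwise{T,N,F,TA<:AbstractArray,TB<:AbstractArray} <: AbstractArray{T,N}
    f::F
    A::TA
    B::TB
end

function LazyPointwise(f, A::AbstractArray, B::AbstractArray)
    T = typeof(f(first(A), first(B)))   # element type inferred from one sample
    N = ndims(A) + ndims(B)
    return LazyPointwise{T,N,typeof(f),typeof(A),typeof(B)}(f, A, B)
end

Base.size(D::LazyPointwise) = (size(D.A)..., size(D.B)...)

function Base.getindex(D::LazyPointwise{T,N}, I::Vararg{Int,N}) where {T,N}
    @boundscheck checkbounds(D, I...)
    na = ndims(D.A)
    i, j = I[1:na], I[na+1:end]          # split into A-indices and B-indices
    return D.f(D.A[i...], D.B[j...])
end
```

The wrapper stores only f, A, and B; deciding which entries to keep around is exactly the caching question.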

What I have in mind is to create a new array type, which holds the following three properties:

With this in mind, block-matching denoising can be computed easily and efficiently:

```julia
patch_size = (7, 7)
r = CartesianIndex(patch_size .÷ 2)   # patch half-size

# here we share both abs2 and ssd results globally
point_distances = pointwise((x, y) -> abs2(x - y), A)
patch_distances = pointwise(A) do p, q
    # SSD between the patches centered at p and q: sum over matching offsets o
    sum(-r:r) do o
        point_distances[CartesianIndex((p + o).I..., (q + o).I...)]
    end
end

# a block matching thus becomes a trivial partial sort on patch_distances
```

Here pointwise returns the array type that I want to add. Because this type mimics the naive 4D array concept, it is intuitive to use.

The design of this array type completely depends on how results are cached. There are two caching strategies that sound promising to me.

static window cache

This is exactly the array abstraction of the existing block-matching code in all those algorithms: they only try to find similar patches in a non-local neighborhood (not in the entire image). Each pixel has a cache block of size patch_stride[k]*prod(window_size)/window_size[k] that stores the results for its non-local neighborhood.

More details are needed on how the indices are computed (getting them right takes some effort), but this is the idea.
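A simplified version of the idea can be sketched as follows. This is not the layout described above: it assumes a patch stride of 1, stores the full search window for every pixel (so memory is (2w+1)^2 times the image size instead of the image size squared), and reuses the symmetry d(p, q) = d(q, p). The function name window_cache and its parameters are hypothetical.

```julia
# Simplified "static window cache" sketch (illustrative layout only):
# cache[o + c, p] holds the patch SSD between the patches centered at p and
# q = p + o, for every offset o in -w:w; invalid pairs hold Inf.
function window_cache(img::AbstractMatrix{<:Real}; patch_radius::Int=1, w::Int=2)
    r = CartesianIndex(patch_radius, patch_radius)
    c = CartesianIndex(w + 1, w + 1)             # cache position of the zero offset
    R = CartesianIndices(img)
    inner = first(R)+r:last(R)-r                 # centers whose patch fits in the image
    offsets = CartesianIndices((-w:w, -w:w))
    cache = fill(Inf, size(offsets)..., size(img)...)
    for p in inner, o in offsets
        q = p + o
        q in inner || continue
        if isfinite(cache[c - o, q])
            # symmetry: d(p, q) == d(q, p) was already stored from q's point of view
            cache[o + c, p] = cache[c - o, q]
        else
            cache[o + c, p] = sum(abs2, view(img, p-r:p+r) .- view(img, q-r:q+r))
        end
    end
    return cache
end
```

A lookup for a pair (p, q) inside the window is then cache[q - p + c, p], and block matching for p is a partial sort over the slice cache[:, :, p].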

Benefit:

Drawback:

FIFO queueing cache

This simple caching strategy can be quite useful because, in many cases, we iterate over the image sequentially. Whatever is evicted from the cache is very likely never to be used again in future iterations.
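A FIFO-cached distance function can be sketched with a Dict for lookups and a Vector recording insertion order. The type FIFODistanceCache and its fields are hypothetical names, and the sketch assumes results are keyed on the (p, q) pair exactly as passed in.

```julia
# Hypothetical FIFO-cached distance function: remembers the last `capacity`
# (p, q) => distance results and evicts the oldest entry first.
struct FIFODistanceCache{K,V,F}
    f::F
    capacity::Int
    store::Dict{K,V}
    order::Vector{K}   # insertion order; the front is evicted first
end

FIFODistanceCache{K,V}(f, capacity::Integer) where {K,V} =
    FIFODistanceCache{K,V,typeof(f)}(f, capacity, Dict{K,V}(), K[])

function (c::FIFODistanceCache)(p, q)
    key = (p, q)
    haskey(c.store, key) && return c.store[key]   # cache hit
    d = c.f(p, q)
    if length(c.order) >= c.capacity
        delete!(c.store, popfirst!(c.order))      # evict the oldest result
    end
    c.store[key] = d
    push!(c.order, key)
    return d
end
```

Eviction costs one popfirst! and one delete!, which keeps the bookkeeping small relative to a patch distance computation.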

Benefit:

Drawback:

others

LRU and other sophisticated cache strategies might not suit this task because of their additional overhead: each distance computation may take only hundreds of ns to several μs, so the cache bookkeeping could cost about as much as the computation it saves. I don't have a good estimate yet.

Plans

I plan to put this in a new package and to implement the "static window cache" first, as a faithful reimplementation, just to see how fast I can get WNNM.

References

[1] Gu, S., Zhang, L., Zuo, W., & Feng, X. (2014). Weighted nuclear norm minimization with application to image denoising. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 2862-2869).

[2] Buades, A., Coll, B., & Morel, J. M. (2005). A non-local algorithm for image denoising. In 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05) (Vol. 2, pp. 60-65). IEEE.

[3] Dabov, K., Foi, A., Katkovnik, V., & Egiazarian, K. (2007). Image denoising by sparse 3-D transform-domain collaborative filtering. IEEE Transactions on Image Processing, 16(8), 2080-2095.

johnnychen94 commented 4 years ago

A preview of this implementation can be found at: https://johnnychen94.github.io/LazyDistances.jl/dev/democards/examples/image%20processing/block_matching/

I'll put effort into stabilizing the API and adding more block-matching algorithms.

timholy commented 4 years ago

FFTs (as used in BlockRegistration.jl) can eliminate the repeated computation, but of course they have a large constant factor (typically ~30). So if you're comparing more than ~30 blocks, it's probably better to use the FFT.

BlockRegistration was designed mostly for spatial 3D images, and since 3^3 = 27, you crest over 30 candidates for anything more than a single-pixel shift. Hence it was a no-brainer to implement block matching with the FFT.
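The break-even arithmetic can be written out directly. The sketch assumes candidates are all integer shifts in -k:k along each of d dimensions, and uses the rough ~30x FFT overhead mentioned above as a default (which, per the edit below, is probably an underestimate); the function names are illustrative.

```julia
# Number of candidate integer shifts when each of `d` dimensions may shift by -k:k.
ncandidates(k, d) = (2k + 1)^d

# Rough rule of thumb: the FFT pays off once the number of compared blocks
# exceeds its constant-factor overhead (assumed ~30 here).
prefer_fft(k, d; overhead=30) = ncandidates(k, d) > overhead
```

In 3D a single-pixel shift gives 27 candidates (just under 30) and a two-pixel shift already gives 125, while in 2D the crossover only happens at three-pixel shifts (25 versus 49 candidates).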

Edit: the overhead for BlockRegistration is probably bigger than 30x. It uses a sophisticated strategy to eliminate boundary effects, and that requires multiple FFTs: https://github.com/HolyLab/RegisterMismatch.jl/blob/master/src/RegisterMismatch.jl, specifically https://github.com/HolyLab/RegisterMismatch.jl/blob/00b9eda886141739f18b34440a381af5bd602dd1/src/RegisterMismatch.jl#L238-L247. The basic idea is summarized in this long-unfinished draft manuscript (sitting on my hard drive since 2006 :scream:):

[Image: excerpt from the draft manuscript, including Eqs. 8 and 9]

johnnychen94 commented 3 years ago

Out of curiosity, how is "mismatch" defined, and what is the criterion for a "mismatch"? In "best match" block matching, for each pixel p in the image we compute the patch distances f(patch_p, patch_q) and return the pixel q with the minimal distance.


If this is too complicated to explain, could you give one or more papers that I can refer to?

timholy commented 3 years ago

I may be misunderstanding your question, but the mismatch is just the mean-square-difference. Eq 8 can be interpreted as follows: the numerator is computing the difference between If (fixed image) and Im (moving image). The pixels in Im are moved, x → g(x) (a vector to a vector), before computing the difference. The novelty this adds compared to most other implementations is the θ stuff, which allows you to mask out certain pixels. In our BlockRegistration code, we marked bad pixels with NaN, and then before we use the images in this formula we do the equivalent of

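```julia
θ = isnan.(img)      # true for bad (NaN) pixels
imgc = copy(img)
imgc[θ] .= 0         # zero out bad pixels so they contribute nothing to the sums
θ = (!).(θ)          # true for non-NaN pixels
```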

When we are only considering translations, Eq 9 applies. The FFT gives us E for all possible x0 and we then just do argmin on the result.

johnnychen94 commented 3 years ago

So, to check that I understand the context: in direct mode, the mismatch (distance) here just calculates

```julia
# direct mode: compare the patch around p in If with the patch around every q in Im
x = p-r:p+r
map(CartesianIndices(Im)) do q
    gx = q-r:q+r
    f(If[x], Im[gx])
end
```

where f is a mask-aware mean-square difference, and the output is then passed to findmin to get the argmin?

timholy commented 3 years ago

That's exactly right. The only thing the extra steps of the derivation add (they are what lay the foundation for using the FFT) is to keep this from being O(N*M) (N = number of pixels, M = number of shifts), making it closer to O(N+M) instead.
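The point of the derivation can be illustrated with the masked mean-square mismatch over translations. Expanding the square turns both the numerator and the denominator into cross-correlations, each of which an FFT can compute for all shifts at once, which is why several FFTs are needed. The sketch below uses a naive shifted_corr in place of the FFT so it stays dependency-free; it is an illustration of the idea, not RegisterMismatch.jl's implementation, and the function names are made up. It assumes NaN marks bad pixels and that the mismatch is the masked mean-square difference over the overlap.

```julia
# corr(a, b)(s) = Σ_x a[x] * b[x + s] over the x where both indices are in bounds.
# Naive stand-in for what an FFT would compute for all shifts s at once.
function shifted_corr(a::AbstractArray, b::AbstractArray, s::CartesianIndex)
    acc = 0.0
    for x in CartesianIndices(a)
        y = x + s
        checkbounds(Bool, b, y) && (acc += a[x] * b[y])
    end
    return acc
end

# Masked mean-square mismatch for every shift in `shifts`, via correlations only.
function masked_mismatch(fixed::AbstractArray, moving::AbstractArray, shifts)
    θf, θm = .!isnan.(fixed), .!isnan.(moving)   # true for valid pixels
    F = ifelse.(θf, fixed, 0.0)                  # NaNs replaced by 0
    M = ifelse.(θm, moving, 0.0)
    F2, M2 = F .^ 2, M .^ 2
    map(shifts) do s
        # Σ θf θm (F - M)^2 = Σ F² θm + Σ θf M² - 2 Σ F M   (F, M are 0 where masked)
        num = shifted_corr(F2, θm, s) + shifted_corr(θf, M2, s) - 2 * shifted_corr(F, M, s)
        den = shifted_corr(θf, θm, s)            # number of overlapping valid pixels
        num / den
    end
end
```

Passing the result to findmin gives the best shift; replacing shifted_corr with FFT-based correlations is what removes the O(N*M) cost.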