ComputationalCryoEM / ASPIRE-Python

Algorithms for Single Particle Reconstruction
http://spr.math.princeton.edu
GNU General Public License v3.0

`FBBasis2D` and `FLEBasis2D` correspondence #738

Open chris-langfield opened 1 year ago

chris-langfield commented 1 year ago

Preserving work done on this problem for posterity

FLEBasis2D will be a new basis class introduced in #693, ported from https://github.com/nmarshallf/fle_2d.

Differences between this basis and the existing FBBasis2D mean that either

Initial basis count heuristic

FBBasis2D sets the count in advance using the following formula:

self.count = self.k_max[0] + sum(2 * self.k_max[1:])

k_max[i] is the number of Bessel zeros (q's in the paper) included for Bessel order i. So this is equivalent to the DC component plus all k,q combinations times 2 (positive and negative).
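As a quick check of the formula, here is a minimal sketch with a made-up k_max (the values are illustrative, not read from a real FBBasis2D instance, and fb_count is a hypothetical helper). Order 0 contributes one function per radial index, and every higher order contributes two, one per sign.

```python
def fb_count(k_max):
    """Count FB basis functions, given k_max[ell] radial indices per angular order ell."""
    # ell = 0 has no sign partner; every ell > 0 appears with both signs.
    return k_max[0] + 2 * sum(k_max[1:])


# Hypothetical k_max: 3 radial functions for ell=0, 2 for ell=1, 1 for ell=2.
k_max = [3, 2, 1]
print(fb_count(k_max))  # 3 + 2 * (2 + 1) = 9
```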

FLEBasis2D uses

self.count = int(self.nres**2 * np.pi / 4)

which is based on the geometry of the problem: pi * R**2 with R = nres / 2, the area of the disc inscribed in the image box.
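To make the heuristic concrete, the sketch below evaluates it for the two resolutions used in this issue. fle_count is a hypothetical helper written for illustration; it only restates the formula above and is not the class's own code.

```python
import math


def fle_count(nres):
    """Heuristic count: area of the disc of radius nres / 2 inscribed in an nres x nres box."""
    radius = nres / 2
    return int(math.pi * radius**2)  # same quantity as int(nres**2 * pi / 4)


for nres in (8, 32):
    print(nres, fle_count(nres))
```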

Basis functions used

FLEBasis2D can be forced to use the FBBasis2D heuristic (match_fb=True), which lets us check whether the same subset of basis functions is chosen.

The two classes store the three parameters that determine a basis function (ell, k, and sign) differently (FBBasis2D._indices vs. FLEBasis2D.ells/.ks), but they can be compared as follows:

L = 32
fle = FLEBasis2D(L, match_fb=True, dtype=np.float64)
fb = FBBasis2D(L, dtype=np.float64)

fb_fns = []
for idx, ell in enumerate(fb._indices["ells"]):
    fb_fns.append((ell, fb._indices["ks"][idx], fb._indices["sgns"][idx]))

def sign(x):
    # ell == 0 has no negative counterpart, so treat it as positive
    if x == 0:
        return 1
    else:
        return np.sign(x)

fle_fns = []
for idx, ell in enumerate(fle.ells):
    fle_fns.append((abs(ell), fle.ks[idx]-1, sign(ell)))

The code above reveals two things to note:

When these are corrected for, we seem to have the same set of functions:

>>> set(fle_fns) == set(fb_fns)
True
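The conversion used above can be isolated in a small pure-Python sketch. The convention assumed here (signed ell and 1-based k on the FLE side, non-negative ell with a separate sign and 0-based k on the FB side) is read off the snippet above; fle_to_fb_key and the toy index lists are hypothetical and not taken from real basis objects.

```python
def fle_to_fb_key(ell, k):
    """Map an FLE (signed ell, 1-based k) pair to an FB-style (ell, k, sgn) tuple."""
    sgn = 1 if ell >= 0 else -1  # ell == 0 is treated as positive
    return (abs(ell), k - 1, sgn)


# Hypothetical toy indices for illustration only.
fle_ells = [0, -1, 1, -2, 2]
fle_ks = [1, 1, 1, 1, 1]
fb_fns = [(0, 0, 1), (1, 0, 1), (1, 0, -1), (2, 0, 1), (2, 0, -1)]

fle_fns = [fle_to_fb_key(ell, k) for ell, k in zip(fle_ells, fle_ks)]
print(set(fle_fns) == set(fb_fns))
```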

The following code sorts both lists of tuples and returns the sorting indices (like np.argsort, but working on tuples). I originally assumed that fb_indices was simply [0, 1, 2, ..., count - 1], on the grounds that the way FB indices are stored already matches the way Python sorts tuples by default, so that only fle_indices would matter. That is not true: fb_indices has to be taken into account as well. The code snippets have been updated accordingly.

fb_indices = sorted(range(len(fb_fns)), key=lambda k: fb_fns[k])
fle_indices = sorted(range(len(fle_fns)), key=lambda k: fle_fns[k])
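The argsort-on-tuples idiom can be checked on a toy case. In the sketch below, argsort_tuples is a hypothetical helper and the labels are made up; the point is only that applying each list's own sorting indices puts both lists in the same order.

```python
def argsort_tuples(items):
    """Return the indices that would sort a list of tuples lexicographically."""
    return sorted(range(len(items)), key=lambda i: items[i])


# Hypothetical (ell, k, sgn) labels, stored in two different orders.
a = [(1, 0, -1), (0, 0, 1), (1, 0, 1)]
b = [(0, 0, 1), (1, 0, 1), (1, 0, -1)]

a_idx = argsort_tuples(a)
b_idx = argsort_tuples(b)

# After sorting, position i refers to the same label in both lists.
print(a_idx, b_idx)
print([a[i] for i in a_idx] == [b[i] for i in b_idx])
```

Neither index list is the identity here, which is why both orderings need to be applied.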

Visually compare matched basis functions using one-hot tests with evaluate

fb_images_sorted = fb.evaluate(np.eye(fb.count)[fb_indices])
fle_images_sorted = fle.evaluate(np.eye(fb.count)[fle_indices])
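For readers unfamiliar with the one-hot trick: indexing the identity matrix by a list of indices yields one coefficient vector per basis function, in that order, so evaluating the stack returns the basis functions themselves. The sketch below mimics this in pure Python; evaluate here is a toy stand-in (a plain linear combination) and not ASPIRE's implementation, and all values are made up.

```python
def one_hot_rows(count, order):
    """Rows of the count x count identity in the given order (like np.eye(count)[order])."""
    return [[1.0 if j == i else 0.0 for j in range(count)] for i in order]


def evaluate(coefs, basis_images):
    """Toy stand-in for evaluate: each output is sum_j coefs[j] * basis_images[j]."""
    n_pix = len(basis_images[0])
    return [
        [sum(c * img[p] for c, img in zip(row, basis_images)) for p in range(n_pix)]
        for row in coefs
    ]


# Three hypothetical 2-pixel "basis images".
basis_images = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
order = [2, 0, 1]

images = evaluate(one_hot_rows(3, order), basis_images)
print(images == [basis_images[i] for i in order])
```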

The functions are visually similar when k=0 but diverge past that point. Note also that the deltas for the similar images are very small (on the order of 1e-8).

Image(fb_images_sorted[10:20]).show()
Image(fle_images_sorted[10:20]).show()
Image(fb_images_sorted[10:20] - fle_images_sorted[10:20]).show()

[Figures: FB, FLE, and deltas (FB - FLE) for sorted indices 10:20]

This seems likely to be related to the off-by-one discrepancy with ks. To be continued.

chris-langfield commented 1 year ago

The solution (but not the explanation)

Summary

Code

The following code is very similar to the original post in this issue, with three modifications marked in the comments. L=8 is used for easy visualization, but the result was also verified at higher resolutions. (It appears to break down again when the resolution is odd, but one thing at a time.)

L = 8
fle = FLEBasis2D(L, match_fb=True, dtype=np.float64)
fb = FBBasis2D(L, dtype=np.float64)

fb_fns = []
for idx, ell in enumerate(fb._indices["ells"]):
    fb_fns.append((ell, fb._indices["ks"][idx], fb._indices["sgns"][idx]))

## First modification: flip signs (note this affects ORDERING, not real sign of resulting function)
fle_fns = []
def sign(x):
    # ell == 0 returns -1 here so that the negation below yields +1
    if x == 0:
        return -1
    else:
        return np.sign(x)
for idx, ell in enumerate(fle.ells):
    fle_fns.append((abs(ell), fle.ks[idx]-1, -sign(ell)))

fb_indices = sorted(range(len(fb_fns)), key=lambda k: fb_fns[k])
fle_indices = sorted(range(len(fle_fns)), key=lambda k: fle_fns[k])

## Second modification: take sorted fb_indices into account
fb_images_sorted = fb.evaluate(np.eye(fb.count)[fb_indices])
fle_images_sorted = fle.evaluate(np.eye(fb.count)[fle_indices])

## Third modification: flip real sign of "negative" FLE functions
for i in range(fle_images_sorted.data.shape[0]):
    if not fle_fns[fle_indices[i]][0]==0 and fle_fns[fle_indices[i]][2] == -1:
        fle_images_sorted.data[i,:,:] = -fle_images_sorted.data[i,:,:]

Now plot:

fb_images_sorted.show()

[Figure: fb_8]

fle_images_sorted.show()

[Figure: fle_8]

Deltas for completeness

(fb_images_sorted - fle_images_sorted).show()

[Figure: deltas_8]

Tolerance:

>>> np.allclose(fb_images_sorted.asnumpy(), fle_images_sorted.asnumpy(), atol=1e-4)
True

chris-langfield commented 1 year ago

Odd resolutions

For odd resolutions, the functions ordered as above look similar but are not the same: they appear to have the same angular and radial frequencies, but they differ numerically.

See L=9:

[Figures: FB (fb_9), FLE (fle_9)]

L=33 (indices 20:45):

[Figures: FB (fb_33_20--45), FLE (fle_33_20--45), deltas (deltas_33_20--45)]

garrettwrong commented 1 year ago

Nice work Chris! Seems like a fine way to round out the week.

chris-langfield commented 1 year ago

Updated notebook

There are two remaining questions regarding the FLE basis

The notebook below contains code creating a one-hot stack of coefficients to test out evaluate between FBBasis2D, FLEBasis2D, and the original implementation with fle_2d, both visually and quantitatively.

Some small modifications to the fle_2d original code were necessary for the comparison to be possible:

The easiest way to add this into the notebook is to clone https://github.com/chris-langfield/fle_2d and pip install -e . from the cloned repo. It will then import as fle_2d in the notebook.

Finally, note that the FB compatibility indexing and sign flipping have been added to #693 so that the FLE outputs are automatically reordered to match FB.

FLE-FB2D-FLE_org-comparison.ipynb.gz

chris-langfield commented 1 year ago

I've narrowed it down: the problem with the slow FLE method (dense matrix multiplication) is not a normalization issue. The output of the dense matrix method is close to both FLE fast and FB, up to sign.

The following code will give True, True, True (in e.g. cell 9 of the notebook)

# FB to FLE fast comparison
(np.allclose(fb_images.asnumpy(), fle_images.asnumpy(), atol=1e-4),
# FB to FLE slow comparison
np.allclose(np.abs(fb_images.asnumpy()), np.abs(fle_images_slow.asnumpy()), atol=1e-4),
# FLE slow to FLE fast comparison
np.allclose(np.abs(fle_images_slow.asnumpy()), np.abs(fle_images.asnumpy()), atol=1e-6)
)

This points to a problem with the application of fle.flip_sign_indices in FLEBasis2D.create_dense_matrix(). Initially the basis functions looked visually similar but I hadn't noticed that some of them are flipped. I've been trying out a few ideas (adjusting the order of applying fb_compat_indices and flip_sign_indices, etc.), but haven't been able to quite get it yet.

Would it be clearer to replace the checks in the #693 tests with the np.abs comparisons, with a note that the discrepancy is only in the signs of certain basis functions? That seems better than keeping the 1e-1 tolerance, now that we know the reason for it.
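One possible alternative to comparing absolute values, sketched here as a suggestion rather than something #693 does, is to check each image against both the reference and its negation and report which indices are flipped. The helper names and data are hypothetical, and unlike np.allclose this check uses an absolute tolerance only.

```python
def close(a, b, atol):
    """True if two flat images agree elementwise within atol."""
    return all(abs(x - y) <= atol for x, y in zip(a, b))


def flipped_indices(ref_images, test_images, atol=1e-4):
    """Indices where test matches -ref but not ref; raises if neither matches."""
    flipped = []
    for i, (ref, test) in enumerate(zip(ref_images, test_images)):
        if close(ref, test, atol):
            continue
        if close(ref, [-t for t in test], atol):
            flipped.append(i)
        else:
            raise ValueError(f"image {i} differs by more than a sign")
    return flipped


# Hypothetical flattened images; the second one is negated.
ref = [[1.0, 2.0], [0.5, -0.5], [3.0, 0.0]]
test = [[1.0, 2.0], [-0.5, 0.5], [3.0, 0.0]]
print(flipped_indices(ref, test))
```

This is stricter than an np.abs comparison, which would also accept images whose individual pixels differ in sign.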

garrettwrong commented 1 year ago

In my review (which is pending), I think there was a bug with that code. Specifically, it looks like you negate only half the indices. Let me see if I can find it.

garrettwrong commented 1 year ago

(Removed comment following discussion with Chris in the daily meeting and his detail below; it seems we don't think that area is a problem.)

chris-langfield commented 1 year ago

That specific part actually does flip the right signs as far as I can tell. Without flipping the sign at those indices you get the below

[Figures: FB (FB_6), FLE (FLE_6)]

The images at indices 5, 6, 9, 10, 12, 14, 16 are flipped. But when you put that line back in, they're ~1e-5 close.

The flip_sign_indices are [2, 4, 7, 9, 11, 13, 16], but those are the indices prior to the fb_compat_indices reordering.

I believe the issue is where/how these same indices are mapped to the matrix here:

https://github.com/ComputationalCryoEM/ASPIRE-Python/blob/00686f7c564766d2b1f530157a267d6be397ee93/src/aspire/basis/fle_2d.py#L648-L653
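To illustrate why the order of operations matters, here is a toy sketch using scalars in place of basis functions (or matrix columns). It assumes the reordering convention new[i] = old[compat_indices[i]]; the permutation, flip indices, and helper names are all hypothetical, and the sketch makes no claim about what create_dense_matrix() actually does.

```python
def reorder(items, compat_indices):
    """New ordering: position i holds the item that was at compat_indices[i]."""
    return [items[j] for j in compat_indices]


def flip(items, positions):
    """Negate the items at the given positions."""
    positions = set(positions)
    return [-x if i in positions else x for i, x in enumerate(items)]


def map_flip_positions(flip_indices, compat_indices):
    """Translate flip indices from the original ordering to the reordered one."""
    flip_set = set(flip_indices)
    return [i for i, j in enumerate(compat_indices) if j in flip_set]


# Hypothetical values for illustration.
items = [10, 20, 30, 40, 50]
compat_indices = [3, 0, 4, 1, 2]
flip_indices = [1, 4]  # defined in the ORIGINAL ordering

flip_then_reorder = reorder(flip(items, flip_indices), compat_indices)
mapped = map_flip_positions(flip_indices, compat_indices)
reorder_then_flip = flip(reorder(items, compat_indices), mapped)

print(mapped)
print(flip_then_reorder == reorder_then_flip)
```

Applying the unmapped flip_indices after reordering negates the wrong entries, which would produce exactly the kind of partially flipped output described above.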

chris-langfield commented 1 year ago

Radial convolution in FLEBasis2D

As noted in the discussion in #693, none of the CTF functions generated by our code (RadialCTFFilter.evaluate_grid()) give the same result as the hardcoded ctf_32x32.npy file (see the last test in test_FLEBasis2D).

Some notes from things I tried:

chris-langfield commented 1 year ago

Discussed in dev meeting:

Also noted that the original fle_2d package defaults to the slow matrix method for L<16. Discussed that the issues with the matrix method for reordering will need to be addressed before implementing for ASPIRE's FLE.

chris-langfield commented 1 year ago

#858 solves the odd grid issue for fast evaluate and evaluate_t. Using create_dense_matrix() with match_fb still results in some images / coefficients being flipped, which I believe is the last correspondence issue.