PolarizedLightFieldMicroscopy / GeoBirT

Geometrical Birefringence Tomography
BSD 3-Clause "New" or "Revised" License

Predicted retardance image resembles rotated measured retardance image #89

Closed: gschlafly closed this issue 4 months ago

gschlafly commented 4 months ago

Description

The retardance image of the predicted object sometimes resembles the rotated measured retardance image. With more iterations, the predicted retardance image becomes nonzero only in the region where the measured retardance image and its rotated version overlap. This issue has been particularly evident with non-symmetric objects and when using omit_rays_based_on_pixels=True. The set of voxels that each ray is traced through is not always correct.
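A toy illustration of the symptom, assuming (as the later comments in this thread conclude) that the pixel mask used with omit_rays_based_on_pixels=True is effectively transposed relative to the image. The image, mask, and variable names below are made up for illustration and are not taken from the repository; the point is that predicted pixels can only survive where the kept-pixel mask and the measured support overlap.

```python
import numpy as np

# Toy "measured retardance" image of one asymmetric, L-shaped object.
# Purely illustrative data.
ret_meas = np.zeros((8, 8))
ret_meas[1:7, 1:3] = 1.0  # vertical bar
ret_meas[5:7, 1:6] = 1.0  # horizontal foot

# Pixels that are nonzero in the measurement.
measured_support = ret_meas > 0

# Assumed bug: the mask of kept pixels is built with (row, col) swapped,
# which is the same as transposing it.
kept_pixels = measured_support.T

# Predicted pixels can only be nonzero where rays were kept, and the loss
# pushes them toward zero wherever the measurement is zero, so roughly only
# the overlap of the two masks can stay nonzero.
surviving_support = kept_pixels & measured_support

print("measured pixels:", measured_support.sum())
print("kept pixels:    ", kept_pixels.sum())
print("overlap pixels: ", surviving_support.sum())
```

Note that a transpose is a reflection rather than a rotation, which matches the diagnosis reached later in this thread.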

Screenshots

Experimental xylem cells: [image]
Experimental xylem cells after even more iterations: [image]

Simulated shifted birefringent voxel: [image]

Files

To Reproduce

Run the function recon_xylem() in the script xylem.py and observe the reconstruction results. These xylem images are large, so the process is computationally intensive. A less computationally intensive example that appears to reproduce the same issue is a shifted birefringent voxel: run the function recon_voxel_shifted() in the script run_recon.py and observe the reconstruction results.

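import os

import numpy as np

# The remaining names (setup_optical_parameters, ForwardModel, BirefringentVolume,
# Reconstructor, BACKEND, volume_args, etc.) come from the GeoBirT code base and
# are assumed to already be in scope, as they are in run_recon.py.
def recon_voxel_shifted():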
    optical_info = setup_optical_parameters(
        "config_settings/optical_config_voxel.json")
    postfix = get_forward_img_str_postfix(optical_info)
    forward_img_str = 'voxel_pos_shiftedy' + postfix + '.npy'
    simulate = True
    if simulate:
        optical_system = {'optical_info': optical_info}
        simulator = ForwardModel(optical_system, backend=BACKEND)
        volume_GT = BirefringentVolume(
            backend=BACKEND,
            optical_info=optical_info,
            volume_creation_args=volume_args.voxel_shiftedy_args
        )
        visualize_volume(volume_GT, optical_info)
        simulator.forward_model(volume_GT)
        simulator.view_images()
        ret_image_meas = simulator.ret_img
        azim_image_meas = simulator.azim_img
        # Save the images as numpy arrays
        save_images = True
        if save_images:
            ret_numpy = ret_image_meas.detach().numpy()
            np.save('forward_images/ret_' + forward_img_str, ret_numpy)
            azim_numpy = azim_image_meas.detach().numpy()
            np.save('forward_images/azim_' + forward_img_str, azim_numpy)
    else:
        ret_image_meas = np.load(os.path.join(
            'forward_images', 'ret_' + forward_img_str))
        azim_image_meas = np.load(os.path.join(
            'forward_images', 'azim_' + forward_img_str))

    recon_optical_info = optical_info.copy()
    iteration_params = setup_iteration_parameters(
        "config_settings/iter_config.json")
    initial_volume = BirefringentVolume(
        backend=BACKEND,
        optical_info=recon_optical_info,
        volume_creation_args=volume_args.random_args
    )
    recon_directory = create_unique_directory("reconstructions")
    if not simulate:
        volume_GT = initial_volume
    recon_config = ReconstructionConfig(recon_optical_info, ret_image_meas,
        azim_image_meas, initial_volume, iteration_params, gt_vol=volume_GT
    )
    recon_config.save(recon_directory)
    reconstructor = Reconstructor(recon_config, omit_rays_based_on_pixels=True, apply_volume_mask=True)
    reconstructor.reconstruct(output_dir=recon_directory)
    visualize_volume(reconstructor.volume_pred, reconstructor.optical_info)

It may be helpful to examine the boolean values of the attribute nonzero_pixels_dict, initialized here: https://github.com/PolarizedLightFieldMicroscopy/GeoBirT/blob/8002bff0c4cf8263ab04aa18563cb1e689d42b02/VolumeRaytraceLFM/abstract_classes.py#L526-L528
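Since nonzero_pixels_dict holds boolean values, one quick diagnostic is to collect those flags into a 2D array laid out like the retardance image and check whether it agrees better with the measured support or with its transpose. This is a hedged sketch: the helper check_mask_orientation is hypothetical, and how nonzero_pixels_dict is flattened into an image is left as an assumption, since its exact layout is not described here.

```python
import numpy as np


def check_mask_orientation(kept_mask, ret_meas, threshold=0.0):
    """Compare a boolean kept-pixel mask against the measured support.

    kept_mask: square 2D boolean array, True where pixels/rays are kept
        (hypothetically assembled by hand from nonzero_pixels_dict).
    ret_meas: square 2D measured retardance image of the same shape.
    Returns the fraction of agreeing pixels for the mask as-is and for
    its transpose.
    """
    support = ret_meas > threshold
    agree_direct = float(np.mean(kept_mask == support))
    agree_transposed = float(np.mean(kept_mask.T == support))
    return agree_direct, agree_transposed


# Usage with toy data: a mask that was (wrongly) built transposed.
ret = np.zeros((4, 4))
ret[0, 1:4] = 1.0
mask = (ret > 0).T
direct, transposed = check_mask_orientation(mask, ret)
print(f"agreement as-is: {direct:.3f}, after transpose: {transposed:.3f}")
```

If the transposed agreement is clearly higher, the mask indices are likely swapped.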

gschlafly commented 4 months ago

The following code suggests that the measured retardance image is best aligned with the transpose of the predicted retardance image. Thus, the two components of the microlens index should be swapped when the index is first assigned.

# %%
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt

# %%
data_dir = "xylem_mid_outputs/"  # avoid shadowing the built-in dir()
ret_meas = np.load(data_dir + "ret_meas.npy")
ret_pred = np.load(data_dir + "ret_pred_omitrays.npy")

# %%
def rotate_and_compare(ret_pred, ret_meas, angle):
    """
    Rotates ret_pred by a given angle, subtracts it from ret_meas,
    and calculates the mean squared error.
    """
    rotated_img = scipy.ndimage.rotate(ret_pred, angle, reshape=False)
    diff = ret_meas - rotated_img
    mse = np.mean(np.square(diff))
    return mse, rotated_img

# %%
angles = range(-180, 181, 10)  # Adjust the range and step of angles as needed
errors = []

for angle in angles:
    mse, _ = rotate_and_compare(ret_pred, ret_meas, angle)
    errors.append(mse)

# Find the angle with the minimum error
min_error_idx = np.argmin(errors)
best_angle = angles[min_error_idx]

# Rotate the image using the best angle found
_, best_rotated_img = rotate_and_compare(ret_pred, ret_meas, best_angle)

# %%
# Plotting
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(ret_pred, cmap='gray')
axs[0].set_title('Original')
axs[1].imshow(best_rotated_img, cmap='gray')
axs[1].set_title(f'Best Rotation: {best_angle} degrees')
axs[2].imshow(ret_meas, cmap='gray')
axs[2].set_title('Target')
plt.show()

print(f"Best rotation angle: {best_angle} degrees")

# %% [markdown]
# ## Check mirroring

# %%
def calculate_mse(image1, image2):
    """Calculate the mean squared error between two images."""
    return np.mean((image1 - image2) ** 2)

def mirror_along_main_diagonal(image):
    """Mirror an image along its main diagonal (a transpose). The result keeps the input's shape only for square images."""
    return np.transpose(image)

def mirror_along_secondary_diagonal(image):
    """Mirror an image along its secondary (anti-)diagonal. Intended for square images."""
    return np.fliplr(np.flipud(np.transpose(image)))

# %%
# Prepare the mirrored versions
mirrored_horizontally = np.fliplr(ret_pred)
mirrored_vertically = np.flipud(ret_pred)
mirrored_main_diagonal = mirror_along_main_diagonal(ret_pred)
mirrored_secondary_diagonal = mirror_along_secondary_diagonal(ret_pred)

# Calculate MSE for each mirrored version
mse_original = calculate_mse(ret_pred, ret_meas)
mse_horizontal = calculate_mse(mirrored_horizontally, ret_meas)
mse_vertical = calculate_mse(mirrored_vertically, ret_meas)
mse_main_diagonal = calculate_mse(mirrored_main_diagonal, ret_meas)
mse_secondary_diagonal = calculate_mse(mirrored_secondary_diagonal, ret_meas)

# Find which version has the lowest MSE (i.e., best match)
mse_values = [mse_original, mse_horizontal, mse_vertical, mse_main_diagonal, mse_secondary_diagonal]
best_match_index = np.argmin(mse_values)
options = ["Original", "Horizontally Mirrored", "Vertically Mirrored", "Main Diagonal Mirrored", "Secondary Diagonal Mirrored"]
best_description = options[best_match_index]

# %% [markdown]
# ### Plot

# %%
# Prepare the best matching image for plotting
best_matching_image = [ret_pred, mirrored_horizontally, mirrored_vertically, mirrored_main_diagonal, mirrored_secondary_diagonal][best_match_index]

# %%
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(ret_pred, cmap='gray')
axs[0].set_title('Original')
axs[1].imshow(best_matching_image, cmap='gray')
axs[1].set_title(f'Best Match: {best_description}')
axs[2].imshow(ret_meas, cmap='gray')
axs[2].set_title('Target')
plt.show()

print(f"Best match: {best_description} with MSE: {mse_values[best_match_index]}")
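The notebook above combines an interpolated rotation search in 10-degree steps with a separate check of four mirrorings. For square images, an exhaustive and interpolation-free alternative is to test the eight symmetries of the pixel grid (four rotations by multiples of 90 degrees, each with and without a transpose) using np.rot90. A sketch; best_dihedral_match is a hypothetical helper, not part of the repository.

```python
import numpy as np


def best_dihedral_match(ret_pred, ret_meas):
    """Try all 8 rotations/reflections that map a square pixel grid onto
    itself and return the name of the one with the lowest MSE, plus all MSEs."""
    candidates = {}
    for k in range(4):
        rotated = np.rot90(ret_pred, k)  # counterclockwise by 90*k degrees
        candidates[f"rot{90 * k}"] = rotated
        candidates[f"rot{90 * k}+transpose"] = rotated.T  # a reflection
    errors = {
        name: float(np.mean((img - ret_meas) ** 2))
        for name, img in candidates.items()
    }
    best = min(errors, key=errors.get)
    return best, errors


# Usage with toy data: the "measurement" is the transpose of the prediction.
pred = np.arange(16, dtype=float).reshape(4, 4)
meas = pred.T.copy()
best, errors = best_dihedral_match(pred, meas)
print("best transform:", best)
```

Because every transform is an exact index permutation, a perfect orientation mismatch shows up as an MSE of exactly zero rather than a blurred minimum.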

gschlafly commented 4 months ago

Adding mla_index = (mla_index[1], mla_index[0]) at the beginning of the method at https://github.com/PolarizedLightFieldMicroscopy/GeoBirT/blob/615432d514df3387db9a547718899372ff6c897f/VolumeRaytraceLFM/birefringence_implementations.py#L1794-L1819 solves the issue. As a more direct fix, we will define mla_index with the correct index ordering from the start.
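To see why swapping the two components of mla_index corrects the orientation, note that computing each microlens's value from (j, i) instead of (i, j) produces exactly the transpose of the original image. A minimal sketch with a hypothetical per-lens function lens_value and a hypothetical assemble loop, not the repository's actual method:

```python
import numpy as np


def lens_value(mla_index):
    """Hypothetical per-microlens quantity that differs for every lens."""
    i, j = mla_index
    return 10 * i + j


def assemble(n_lenses, swap_index=False):
    """Fill an n_lenses x n_lenses image, one value per microlens."""
    img = np.zeros((n_lenses, n_lenses))
    for i in range(n_lenses):
        for j in range(n_lenses):
            mla_index = (i, j)
            if swap_index:
                # The workaround from this thread: swap the components up front.
                mla_index = (mla_index[1], mla_index[0])
            img[i, j] = lens_value(mla_index)
    return img


original = assemble(3)
swapped = assemble(3, swap_index=True)
print(original)
print(swapped)
```

Here swapped equals original.T, so if the original ordering yields an image transposed relative to the measurement, the swap realigns them.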