PolarizedLightFieldMicroscopy / GeoBirT

Geometrical Birefringence Tomography
BSD 3-Clause "New" or "Revised" License

Propagate gradients through a subset of an array #67

Closed gschlafly closed 5 months ago

gschlafly commented 7 months ago

Description

Suppose the birefringence array is the concatenation of two subarrays. We want to backpropagate gradients to only one of the two subarrays, so that only that subarray is updated by the optimizer.

The two subarrays can be prepared as follows:

# Split the birefringence array in half along the first dimension.
Delta_n = volume.Delta_n
length = Delta_n.size(0)
half_length = length // 2
# The first half is trainable; the second half is held fixed.
volume.Delta_n_first_part = torch.nn.Parameter(Delta_n[:half_length].clone(), requires_grad=True)
volume.Delta_n_second_part = torch.nn.Parameter(Delta_n[half_length:].clone(), requires_grad=False)

The optimizer is created as follows, where trainable_parameters includes the nn.Parameter volume.Delta_n_first_part:

torch.optim.Adam(trainable_parameters, lr=training_params['lr'])

In each iteration, we do the following before applying the forward model:

Delta_n_combined = torch.cat([volume.Delta_n_first_part, volume.Delta_n_second_part], dim=0)
Delta_n_combined.retain_grad()
volume.birefringence = Delta_n_combined

In the forward pass, the calculations are done with volume.birefringence instead of volume.Delta_n.
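The setup described above can be sketched in isolation to check where gradients end up. This is a minimal sketch, not the GeoBirT code: the six-element tensor standing in for volume.Delta_n, the sum-of-squares loss standing in for the forward model, and the names first_part and second_part are all assumptions made for illustration.

```python
import torch

# Hypothetical stand-in for volume.Delta_n: a 1-D tensor of six voxel values.
delta_n = torch.linspace(-0.01, 0.01, steps=6)
half_length = delta_n.size(0) // 2

# Only the first half is trainable; the second half is held fixed.
first_part = torch.nn.Parameter(delta_n[:half_length].clone(), requires_grad=True)
second_part = torch.nn.Parameter(delta_n[half_length:].clone(), requires_grad=False)
initial_first = first_part.detach().clone()
initial_second = second_part.detach().clone()

# Only the trainable half is handed to the optimizer.
optimizer = torch.optim.Adam([first_part], lr=1e-3)

for _ in range(3):
    optimizer.zero_grad()
    # Rebuild the combined tensor every iteration so it belongs to the current graph.
    combined = torch.cat([first_part, second_part], dim=0)
    combined.retain_grad()  # keep the gradient of this non-leaf tensor for inspection
    loss = (combined ** 2).sum()  # toy loss, not the GeoBirT forward model
    loss.backward()
    optimizer.step()

print("first_part.grad: ", first_part.grad)   # populated by backward()
print("second_part.grad:", second_part.grad)  # stays None (requires_grad=False)
print("combined.grad:   ", combined.grad)     # covers all six entries
```

In this sketch, combined.grad holds a gradient for every entry, but only first_part receives a gradient as a leaf tensor, so only the first half changes across iterations.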

Files

To Reproduce

Go to the alt_delta branch. The latest commit is 7e6951f00c9e14995372e40343be150285988a37. Set OPTIMIZING_MODE = False in birefringence_implementations.py and DEBUG_MODE = True in reconstruction.py. Then view the print statements after loss.backward(). You should also observe that only a subset of the birefringence values updates across iterations.

gschlafly commented 5 months ago

In 50aa4c2, I was able to cause only a subset of the birefringence array to be optimized by applying a mask to Delta_n.grad before optimizer.step().
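The gradient-masking approach can be illustrated with a small standalone sketch. It is not the code from 50aa4c2: the six-element parameter, the boolean mask, and the toy loss are assumptions chosen only to show the pattern of zeroing part of the gradient before the optimizer step.

```python
import torch

# Hypothetical stand-in for Delta_n: a single trainable 1-D parameter.
delta_n = torch.nn.Parameter(torch.linspace(-0.01, 0.01, steps=6))
initial = delta_n.detach().clone()

# Hypothetical mask: True marks the voxels that should be optimized.
mask = torch.zeros(6, dtype=torch.bool)
mask[:3] = True

optimizer = torch.optim.Adam([delta_n], lr=1e-3)

for _ in range(3):
    optimizer.zero_grad()
    loss = (delta_n ** 2).sum()  # toy loss, not the GeoBirT forward model
    loss.backward()
    # Zero the gradient of every voxel outside the mask before stepping.
    delta_n.grad[~mask] = 0.0
    optimizer.step()

print("changed:  ", delta_n.detach()[mask] - initial[mask])
print("unchanged:", delta_n.detach()[~mask] - initial[~mask])
```

With plain Adam and a mask that stays fixed from the first iteration, the masked-out entries see a zero gradient and zero moment estimates, so they do not move. If weight decay is enabled, or if the mask is applied only after some unmasked steps, the optimizer can still change those entries, so the values may need to be restored explicitly in that case.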

gschlafly commented 5 months ago

For the PyTorch backend, the shifted collision voxels are stored in the dictionary rays.vox_indices_ml_shifted.

gschlafly commented 5 months ago

In commit https://github.com/PolarizedLightFieldMicroscopy/GeoBirT/commit/57a56075621c91c728a60dab592029478169171f, the class instance attribute vox_indices_by_mla_idx was created, which contains the voxels involved with each ray and microlens.