Introduced a bug in the coordinate system code. I thought it would be wise to call `jax.lax.stop_gradient` when getting coordinate systems wrapped in AbstractCoordinates objects (see, for example, the method `cryojax.coordinates.FrequencyGrid.get()`), because I was thinking of coordinate systems as buffers. However, this stops gradients with respect to any variable that operates on a coordinate system, such as coordinate system rotations and pixel size multiplications (e.g. `cryojax.simulator.AxisAnglePose.euler_vector` and `cryojax.simulator.ImageConfig.pixel_size`).
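The sketch below reproduces the failure mode in isolation. It is not cryojax code. `Coordinates`, `get`, `scale`, and `loss` are hypothetical stand-ins for a coordinate wrapper whose getter calls `jax.lax.stop_gradient`. The pixel size scales the grid and the result is wrapped again, so reading it back through `get()` cuts the gradient to zero.

```python
import jax
import jax.numpy as jnp


class Coordinates:
    """Hypothetical stand-in for a coordinate-system wrapper (not the cryojax API)."""

    def __init__(self, array, stop_grad):
        self.array = array
        self.stop_grad = stop_grad

    def get(self):
        # Buggy behavior when stop_grad=True: the coordinates are treated as a
        # non-differentiable buffer every time they are read.
        if self.stop_grad:
            return jax.lax.stop_gradient(self.array)
        return self.array

    def scale(self, factor):
        # An operation on the coordinate system (here a pixel size
        # multiplication) returns a new wrapped object.
        return Coordinates(self.get() * factor, self.stop_grad)


def loss(pixel_size, stop_grad):
    grid = Coordinates(jnp.linspace(-1.0, 1.0, 5), stop_grad)
    physical = grid.scale(pixel_size)
    # The scaled coordinates are read back through get(), as a model would.
    return jnp.sum(physical.get() ** 2)


# Differentiate only with respect to pixel_size; stop_grad is closed over.
grad_buggy = jax.grad(lambda p: loss(p, True))(2.0)
grad_fixed = jax.grad(lambda p: loss(p, False))(2.0)
print(grad_buggy, grad_fixed)
```

With `stop_grad=True` the gradient is zero. Without it, the gradient is 2 · pixel_size · Σx² = 10 at `pixel_size = 2.0`. Under these assumptions, calling `stop_gradient` on the base grid alone would be harmless, because it depends on no parameters. The problem comes from applying it again to coordinates that have already been rotated or rescaled.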