mjo22 / cryojax

Cryo electron microscopy image simulation and analysis built on JAX.
https://mjo22.github.io/cryojax/
GNU Lesser General Public License v2.1

Bug that stops gradients with respect to variables that are used to operate on coordinate systems #194

Closed: mjo22 closed this issue 7 months ago

mjo22 commented 7 months ago

I introduced a bug in the coordinate system code. Because I was thinking of coordinate systems as buffers (fixed, non-trainable arrays), I thought it would be wise to call jax.lax.stop_gradient when retrieving a coordinate system wrapped in an AbstractCoordinates object (see the method cryojax.coordinates.FrequencyGrid.get() for an example). However, this also stops gradients with respect to variables that operate on coordinate systems, such as coordinate system rotations and pixel size multiplications (e.g. cryojax.simulator.AxisAnglePose.euler_vector and cryojax.simulator.ImageConfig.pixel_size). Since the stop_gradient is applied when the array is read out, any rotation or rescaling already applied to the wrapped coordinates loses its dependence on the variable that produced it, and the gradient with respect to that variable comes out as zero.
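The failure mode can be reproduced with a minimal sketch, assuming a hypothetical wrapper class `BufferedCoordinates` that stands in for an `AbstractCoordinates`-style object. It is not cryojax's actual implementation: it only mimics the pattern of scaling the wrapped array and then calling `jax.lax.stop_gradient` inside `get()`. The names `BufferedCoordinates`, `stop_grad`, and `loss` are illustrative.

```python
import jax
import jax.numpy as jnp


class BufferedCoordinates:
    """Hypothetical stand-in for a coordinate wrapper (not cryojax's class)."""

    def __init__(self, array, stop_grad):
        self.array = array
        self.stop_grad = stop_grad

    def __mul__(self, scalar):
        # Scaling returns a new wrapper whose stored array now depends on `scalar`
        return BufferedCoordinates(self.array * scalar, self.stop_grad)

    def get(self):
        if self.stop_grad:
            # Buggy variant: treats the (already scaled) array as a constant buffer
            return jax.lax.stop_gradient(self.array)
        # Fixed variant: return the array so gradients can flow through it
        return self.array


def loss(pixel_size, stop_grad):
    # Dimensionless grid, rescaled to physical units by the pixel size
    grid = BufferedCoordinates(jnp.arange(4.0), stop_grad)
    physical_grid = grid * pixel_size
    return jnp.sum(physical_grid.get() ** 2)


# Differentiate with respect to pixel_size only; stop_grad stays a Python bool
grad_buggy = jax.grad(lambda p: loss(p, stop_grad=True))(2.0)
grad_fixed = jax.grad(lambda p: loss(p, stop_grad=False))(2.0)
print(grad_buggy, grad_fixed)
```

Here the loss is `14 * pixel_size**2`, so the true gradient at `pixel_size = 2.0` is 56, but the buggy `get()` returns a zero gradient. A rotation applied to the wrapper before `get()` would be cut off in the same way. One possible design, under these assumptions, is to drop `stop_gradient` from `get()` and let callers stop gradients explicitly when they really want a coordinate grid to act as a constant.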