We found a bug affecting training using pose_estimation: abinit in which the indices of poses kept between pose search iterations were being incorrectly updated.
A previous version updated the division calls in src/pose_search.py to use torch.div(), in order to silence the warning messages raised when the standard __floordiv__ operator is used in place of the torch.div() method. On line 570, we replaced keep_b = keep_bn * batch_size // loss.shape[0] with keep_b = keep_bn * torch.div(batch_size, loss.shape[0], rounding_mode='trunc').
However, this changed the order of operations: the multiplication is supposed to be executed before the division, but the new expression performs the division first. Since loss.shape[0] equals batch_size * nkeptposes, the torch.div() call truncates to 0 whenever more than one pose is kept, and keep_b becomes an index of all 0s. This leads to the batch being filled with just the first image and CTF for subsequent iterations.
We have since updated this line to keep_b = torch.div(keep_bn * batch_size, loss.shape[0], rounding_mode='trunc'), which performs the multiplication before the division and so fixes the order-of-operations issue. We have also added more unit tests and regression tests to ensure that the pose search is now consistent with the original method!
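
The difference between the three expressions can be illustrated with a minimal sketch in plain Python. It uses integer lists instead of tensors and relies on the fact that, for non-negative integers, Python's // gives the same result as torch.div() with rounding_mode='trunc'. The values of batch_size and n_kept_poses, the contents of keep_bn, and the name loss_rows (standing in for loss.shape[0]) are made up for illustration and are not taken from src/pose_search.py.

```python
# Illustrative values only; not taken from the actual training code.
batch_size = 4
n_kept_poses = 8
loss_rows = batch_size * n_kept_poses  # stands in for loss.shape[0]

# Hypothetical flat indices of the poses kept after one search iteration.
keep_bn = [0, 5, 9, 14, 17, 23, 26, 31]

# Original expression: * and // have equal precedence and associate
# left to right, so the multiplication happens first.
original = [i * batch_size // loss_rows for i in keep_bn]

# Buggy rewrite: the division is evaluated on its own first, and
# batch_size // (batch_size * n_kept_poses) truncates to 0.
buggy = [i * (batch_size // loss_rows) for i in keep_bn]

# Fixed rewrite: the product is formed before dividing.
fixed = [(i * batch_size) // loss_rows for i in keep_bn]

print("original:", original)
print("buggy:   ", buggy)
print("fixed:   ", fixed)
```

With these values the buggy expression maps every kept pose to batch index 0, while the fixed expression reproduces the original mapping.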