ml-struct-bio / drgnai

GNU General Public License v3.0

Pose search bug involving torch.div() when choosing kept poses #8

Closed: michal-g closed this issue 2 months ago

michal-g commented 2 months ago

We found a bug affecting training with pose_estimation: abinit, in which the indices of the poses kept between pose search iterations were updated incorrectly.

A previous version updated the torch.div() calls in src/pose_search.py to get rid of the warning messages triggered by applying the standard __floordiv__ (//) operator to tensors instead of calling torch.div(). On line 570, we therefore replaced keep_b = keep_bn * batch_size // loss.shape[0] with keep_b = keep_bn * torch.div(batch_size, loss.shape[0], rounding_mode='trunc').

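However, this changed the order of operations. In the original expression, * and // have equal precedence and are evaluated left to right, so keep_bn * batch_size is computed before the division; the rewrite divides first. Since loss.shape[0] equals batch_size * nkeptposes, torch.div(batch_size, loss.shape[0], rounding_mode='trunc') evaluates to 0 whenever nkeptposes > 1, and keep_b becomes an index of all zeros. As a result, in subsequent iterations the batch is filled entirely with the first image and its CTF.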
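The following is a minimal, hypothetical reproduction of the failure. It assumes, for illustration only, that loss has batch_size * nkeptposes rows and that keep_bn holds flat row indices into it; the sizes, the particular keep_bn values and the stand-in images tensor are invented, not taken from src/pose_search.py. torch.div expects a tensor as its first argument, so this sketch wraps the integer batch_size in torch.tensor(); how the real code supplies that value is not shown in the issue.

```python
import torch

# Hypothetical sizes; as in the issue, loss.shape[0] == batch_size * nkeptposes.
batch_size = 4
nkeptposes = 3
loss = torch.rand(batch_size * nkeptposes, 8)  # only its shape matters here

# Hypothetical flat row indices into loss (one kept row per image).
keep_bn = torch.tensor([0, 4, 7, 11])

# Original expression: * and // share precedence and run left to right,
# so this is (keep_bn * batch_size) // loss.shape[0].
keep_b_original = keep_bn * batch_size // loss.shape[0]

# Buggy rewrite: the division happens first and truncates 4 / 12 to 0.
keep_b_buggy = keep_bn * torch.div(
    torch.tensor(batch_size), loss.shape[0], rounding_mode='trunc'
)

# Stand-in for the batch of images: image i is labelled i.
images = torch.arange(batch_size)

print(keep_b_original)       # → tensor([0, 1, 2, 3])
print(keep_b_buggy)          # → tensor([0, 0, 0, 0])
print(images[keep_b_buggy])  # every slot now holds image 0
```

Note that with nkeptposes = 1 the buggy factor would be 1 rather than 0, which is why the problem only shows up once more than one pose is kept per image.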

We have since changed this line to keep_b = torch.div(keep_bn * batch_size, loss.shape[0], rounding_mode='trunc'), which restores the intended order of operations. We have also added more unit tests and regression tests to ensure that the pose search is now consistent with the original method!
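A regression check in the spirit of the tests mentioned above could compare the corrected expression against the original left-to-right // semantics for every possible flat index. This is a hypothetical sketch, not one of the tests actually added to drgnai; the helper names kept_batch_indices and check_against_reference and the size grid are assumptions made for illustration.

```python
import torch


def kept_batch_indices(keep_bn: torch.Tensor, batch_size: int, n_rows: int) -> torch.Tensor:
    """Hypothetical helper mirroring the fixed line: multiply first, then divide."""
    return torch.div(keep_bn * batch_size, n_rows, rounding_mode='trunc')


def check_against_reference(batch_size: int, nkeptposes: int) -> None:
    """Assert the fixed expression matches (i * batch_size) // n_rows for all i."""
    n_rows = batch_size * nkeptposes  # plays the role of loss.shape[0]
    keep_bn = torch.arange(n_rows)    # every possible flat row index
    # Reference computed with plain Python ints, evaluated left to right.
    expected = torch.tensor([(i * batch_size) // n_rows for i in range(n_rows)])
    actual = kept_batch_indices(keep_bn, batch_size, n_rows)
    assert torch.equal(actual, expected), (batch_size, nkeptposes)


# Sweep a small grid of sizes, including nkeptposes > 1 where the bug appeared.
for batch_size in (1, 4, 16):
    for nkeptposes in (1, 2, 8):
        check_against_reference(batch_size, nkeptposes)
print("fixed expression matches the original // semantics")
```

Building the reference from Python integers, rather than from another tensor expression, keeps the check independent of how any particular PyTorch version implements // on tensors.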