This fixes two issues:
The MALA acceptance rate was computed using only the last parameter tensor in the model's parameter list.
The `MalaAcceptanceRate` callback did not work correctly for models in which some parameters have gradients disabled via `requires_grad = False`. In most cases this resulted in a `RuntimeError` due to mismatched tensor shapes.
This PR also fixes the docstring for `MalaAcceptanceRate`, which did not correctly describe the callback.
These issues were addressed as follows:
Changed the `mala_acceptance_probability` function to accept either tensors or lists of tensors for the current and previous points and gradients, and updated the `MalaAcceptanceRate` `update` method accordingly. Existing calls that pass single tensors still work in the same way.
Filtered out parameters with `requires_grad == False` when creating the `self.current_params` list in `MalaAcceptanceRate`. I also removed the conversion of `model.parameters()` to a list inside this list comprehension, as this looked unnecessary to me.
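To illustrate the two changes above, here is a minimal sketch, not the code from this PR. The names `_as_list` and `mala_acceptance_sketch` are hypothetical, and the proposal is assumed to be the usual Langevin step, a Gaussian centred at `x - (step_size / 2) * grad` with variance `step_size`. The sketch sums the proposal log-densities over every trainable tensor instead of only the last one, and builds the parameter list from tensors with `requires_grad` set, so that points and gradients always line up in shape.

```python
import torch


def _as_list(x):
    """Wrap a single tensor in a list so both input forms share one code path."""
    return [x] if isinstance(x, torch.Tensor) else list(x)


def mala_acceptance_sketch(prev_point, prev_grad, prev_loss,
                           current_point, current_grad, current_loss, step_size):
    """Hypothetical MALA acceptance probability for tensors or lists of tensors.

    Assumes the proposal N(x - (step_size / 2) * grad, step_size * I).
    """
    prev_point, prev_grad = _as_list(prev_point), _as_list(prev_grad)
    current_point, current_grad = _as_list(current_point), _as_list(current_grad)

    log_q_forward = 0.0   # log q(current | prev), up to a constant
    log_q_backward = 0.0  # log q(prev | current), up to the same constant
    for x0, g0, x1, g1 in zip(prev_point, prev_grad, current_point, current_grad):
        # Accumulate over every tensor, not just the last one.
        log_q_forward += -torch.sum((x1 - x0 + 0.5 * step_size * g0) ** 2) / (2 * step_size)
        log_q_backward += -torch.sum((x0 - x1 + 0.5 * step_size * g1) ** 2) / (2 * step_size)

    log_ratio = log_q_backward - log_q_forward + prev_loss - current_loss
    return min(1.0, float(torch.exp(torch.as_tensor(log_ratio))))


# Toy model with the first layer frozen (an assumption for this example).
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Linear(2, 1))
for p in model[0].parameters():
    p.requires_grad = False

# Keep only parameters that will actually receive gradients.
trainable = [p for p in model.parameters() if p.requires_grad]

x, y = torch.randn(8, 3), torch.randn(8, 1)
step_size = 1e-3


def loss_and_grads():
    """Return the scalar loss and a copy of each trainable parameter's gradient."""
    model.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss.item(), [p.grad.clone() for p in trainable]


prev_loss, prev_grads = loss_and_grads()
prev_points = [p.detach().clone() for p in trainable]

# One Langevin step on the trainable parameters only.
with torch.no_grad():
    for p, g in zip(trainable, prev_grads):
        p.add_(-0.5 * step_size * g + step_size ** 0.5 * torch.randn_like(p))

cur_loss, cur_grads = loss_and_grads()
cur_points = [p.detach().clone() for p in trainable]

accept_prob = mala_acceptance_sketch(
    prev_points, prev_grads, prev_loss, cur_points, cur_grads, cur_loss, step_size
)
print(f"acceptance probability: {accept_prob:.4f}")
```

Because frozen parameters never receive a `.grad`, including them in the list would pair a parameter tensor with a missing or differently shaped gradient, which is consistent with the shape mismatch described above.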