timaeus-research / devinterp

Tools for studying developmental interpretability in neural networks.

Fixing two issues in MalaAcceptanceRate callback #87

Closed rohan-hitchcock closed 2 months ago

rohan-hitchcock commented 2 months ago

This fixes two issues:

  1. The MALA acceptance rate was only being computed using the last parameter tensor in the model's parameter list.
  2. The MalaAcceptanceRate callback did not work correctly for models with some parameter gradients disabled using requires_grad = False. In most cases this resulted in a RuntimeError due to mismatched tensor shapes.

It also fixes the docstring for MalaAcceptanceRate, which did not correctly describe the callback.

These issues were addressed as follows:

  1. Changed the mala_acceptance_probability function to accept either tensors or lists of tensors for the current and previous points and gradients, and updated the MalaAcceptanceRate update method accordingly. Existing calls that pass single tensors still work in the same way as before.
  2. Filtered out parameters with requires_grad == False when creating the self.current_params list in MalaAcceptanceRate. I also removed the conversion of model.parameters() to a list inside this list comprehension, as it looked unnecessary.
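
The two changes above can be sketched together as follows. This is a minimal illustration under stated assumptions, not the code in this PR: the names `trainable_params` and `mala_acceptance_sketch` are hypothetical, the proposal is assumed to be the standard MALA Gaussian with step size `step_size`, and the losses are assumed to be plain floats standing in for the negative log density.

```python
import math
from typing import List, Union

import torch
from torch import nn

TensorOrList = Union[torch.Tensor, List[torch.Tensor]]


def trainable_params(model: nn.Module) -> List[torch.Tensor]:
    """Hypothetical helper: keep only parameters that receive gradients."""
    return [p for p in model.parameters() if p.requires_grad]


def _as_list(x: TensorOrList) -> List[torch.Tensor]:
    # Accept a single tensor or a list of tensors, as described in item 1.
    return [x] if isinstance(x, torch.Tensor) else list(x)


def mala_acceptance_sketch(
    prev_point: TensorOrList,
    prev_grad: TensorOrList,
    prev_loss: float,
    current_point: TensorOrList,
    current_grad: TensorOrList,
    current_loss: float,
    step_size: float,
) -> float:
    """Assumed standard MALA acceptance probability, summed over ALL tensors."""
    log_q_forward = 0.0   # log q(current | prev), up to a constant
    log_q_backward = 0.0  # log q(prev | current), up to the same constant
    for xp, gp, xc, gc in zip(
        _as_list(prev_point),
        _as_list(prev_grad),
        _as_list(current_point),
        _as_list(current_grad),
    ):
        # Accumulate across every parameter tensor, not only the last one.
        fwd = torch.sum((xc - xp + 0.5 * step_size * gp) ** 2).item()
        bwd = torch.sum((xp - xc + 0.5 * step_size * gc) ** 2).item()
        log_q_forward -= fwd / (2 * step_size)
        log_q_backward -= bwd / (2 * step_size)

    log_ratio = (prev_loss - current_loss) + (log_q_backward - log_q_forward)
    if log_ratio >= 0.0:
        return 1.0  # also avoids overflow in exp for large positive values
    return math.exp(log_ratio)


# Usage: freeze the first layer, then collect only the trainable tensors.
model = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))
for p in model[0].parameters():
    p.requires_grad = False
params = trainable_params(model)  # weight and bias of the second layer only
```

Summing the squared norms inside the loop is what makes the result independent of how the parameters are split into tensors, and filtering on `requires_grad` keeps the point and gradient lists the same length.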

svwingerden commented 2 months ago

Looks great, though I should really add some tests for the failure mode that this PR fixes. Thanks for the PR!