timaeus-research / devinterp

Tools for studying developmental interpretability in neural networks.

Add in Mala Acceptance Criterion #71

Closed (svwingerden closed this 8 months ago)

svwingerden commented 8 months ago

This PR adds/changes five things:

  1. Adds a MalaAcceptanceRate callback for diagnosing chain health, as in Furman & Lau (2024).
  2. Adds tests that benchmark this implementation against one of Zach Furman's notebooks.
  3. Changes the SGLD and SGNHT classes so that the MALA acceptance rate can be computed semi-efficiently (and fixes the NoiseNorm callback to work the same way).
  4. Updates the sgld_calibration.ipynb notebook to show how and when to use the MALA diagnostic; the notebook also runs considerably faster.
  5. Updates diagnostics.ipynb to show how to use MALA, and removes a NoiseNorm line made obsolete by the NoiseNorm fix.
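For readers unfamiliar with the diagnostic, here is a minimal standalone NumPy sketch of the idea. It is not devinterp's API. It assumes the SGLD step has the form w' = w - (ε/2)∇U(w) + N(0, εI), where U = -log π up to a constant. Under that assumption, the step is exactly a MALA proposal, so its Metropolis-Hastings acceptance probability can be computed even though nothing is ever rejected. The Gaussian target and all function names (potential, grad_potential, log_proposal, mala_acceptance_prob, run_chain) are hypothetical. The caching in run_chain illustrates one way to avoid extra gradient evaluations, because the values at the proposed point are reused as the "current" values on the next step. Whether this matches what "semi-efficiently" means in the PR is an assumption.

```python
import numpy as np


def potential(w):
    """Hypothetical target: standard Gaussian, U(w) = ||w||^2 / 2 = -log pi(w) + const."""
    return 0.5 * float(np.sum(w ** 2))


def grad_potential(w):
    """Gradient of the hypothetical Gaussian potential."""
    return w


def log_proposal(to, frm, grad_frm, eps):
    """Log density (up to a constant) of the Langevin proposal
    q(to | frm) = N(to; frm - (eps / 2) * grad U(frm), eps * I)."""
    mean = frm - 0.5 * eps * grad_frm
    return -float(np.sum((to - mean) ** 2)) / (2.0 * eps)


def mala_acceptance_prob(w, u, g, w_new, u_new, g_new, eps):
    """MH acceptance probability for w -> w_new, treating the unadjusted
    Langevin step as a MALA proposal. Diagnostic only: nothing is rejected."""
    log_ratio = (u - u_new
                 + log_proposal(w, w_new, g_new, eps)
                 - log_proposal(w_new, w, g, eps))
    return float(np.exp(min(0.0, log_ratio)))


def run_chain(w0, eps, num_steps, seed=0):
    """Run unadjusted Langevin dynamics and return the mean MALA acceptance probability."""
    rng = np.random.default_rng(seed)
    w = np.asarray(w0, dtype=float)
    # Cache U and grad U at the current point, so each step evaluates them
    # only once (at the proposed point).
    u, g = potential(w), grad_potential(w)
    probs = []
    for _ in range(num_steps):
        w_new = w - 0.5 * eps * g + rng.normal(scale=np.sqrt(eps), size=w.shape)
        u_new, g_new = potential(w_new), grad_potential(w_new)
        probs.append(mala_acceptance_prob(w, u, g, w_new, u_new, g_new, eps))
        w, u, g = w_new, u_new, g_new  # unadjusted chain: always move
    return float(np.mean(probs))


print(run_chain(np.zeros(10), eps=0.01, num_steps=2000))
print(run_chain(np.zeros(10), eps=1.0, num_steps=2000))
```

For this particular Gaussian target, the log acceptance ratio simplifies to (ε/8)(||w||² - ||w'||²). So a small step size yields acceptance near 1, while a large one yields a visibly lower rate. That is the kind of signal the diagnostic uses to flag a step size that is too aggressive.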