timaeus-research / devinterp

Tools for studying developmental interpretability in neural networks.

Refactor num_samples / temperature, rename elasticity to localization, add warnings, init_loss before sampling, OnlineLLC fix #69

Closed: svwingerden closed this PR 9 months ago

svwingerden commented 9 months ago

This PR includes a few changes:

  1. I refactored the usage of num_samples and temperature in SGLD (and SGNHT). We now pass along only a temperature; if unset, it defaults to the optimal temperature.
  2. I renamed the elasticity parameter in SGLD to localization.
  3. I added a few warnings (e.g. for missing burn-in and for unusual temperature values).
  4. I fixed OnlineLLCEstimator so that it no longer returns an average of a moving average.
  5. If init_loss is not supplied to sample() or any of the estimate() functions, we calculate it explicitly on one batch. Among other things, this allows burn-in to work.
  6. I refactored some tests and added a new seeding test.
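For context, here is a minimal self-contained sketch of how changes (1), (4), and (5) fit together. This is not devinterp's actual code: the names `default_temperature` and the simplified `OnlineLLCEstimator`, and the assumption that the "optimal" default is the inverse temperature 1/log(n) with the LLC estimated as nβ(E[L] − L(w*)), are illustrative assumptions only.

```python
import math


def default_temperature(num_samples: int) -> float:
    # Hypothetical default, assuming the "optimal" inverse temperature
    # beta* = 1 / log(n) from the singular learning theory literature.
    return 1.0 / math.log(num_samples)


class OnlineLLCEstimator:
    """Running LLC estimate that averages the raw losses exactly once,
    illustrating fix (4): no average of a moving average."""

    def __init__(self, num_samples, init_loss, temperature=None):
        # Change (1): callers pass along only a temperature; if unset,
        # it falls back to the optimal default.
        self.temperature = (
            temperature if temperature is not None
            else default_temperature(num_samples)
        )
        self.num_samples = num_samples
        # Change (5): in the library, init_loss would be computed on one
        # batch when not supplied; here we require it explicitly.
        self.init_loss = init_loss
        self.loss_sum = 0.0
        self.draws = 0

    def update(self, loss: float) -> None:
        # Accumulate raw losses; the mean is taken once, in estimate().
        self.loss_sum += loss
        self.draws += 1

    def estimate(self) -> float:
        # Assumed estimator: LLC ~= n * beta * (E[L] - L(w*)).
        mean_loss = self.loss_sum / self.draws
        return self.num_samples * self.temperature * (mean_loss - self.init_loss)
```

Keeping a running sum of raw losses (rather than re-averaging a moving average) means each draw contributes with equal weight to the final estimate.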
svwingerden commented 9 months ago

Fixed the requested issues. The only major change is that both estimate() functions now live in sampler.py rather than llc.py, to get rid of a circular import.
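The move is a generic cure for import cycles: make the dependency one-way. Before, llc.py imported sample() from sampler.py while sampler.py imported estimator classes back from llc.py; with estimate() in sampler.py, the callback module imports nothing from the sampler. A hedged sketch (the function bodies below are hypothetical stand-ins, not devinterp's real signatures):

```python
# --- llc.py (sketch): defines callbacks, imports nothing from sampler ---
class MeanLossCallback:
    """Accumulates the losses observed during sampling."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, loss: float) -> None:
        self.total += loss
        self.count += 1

    def result(self) -> float:
        return self.total / self.count


# --- sampler.py (sketch): owns both sample() and estimate() ---
def sample(losses, callbacks):
    # Drive the chain; `losses` stands in for the per-draw losses.
    for loss in losses:
        for cb in callbacks:
            cb(loss)


def estimate(losses, callback):
    # estimate() can call sample() directly -- no import back into
    # the callback module, so the cycle is broken.
    sample(losses, [callback])
    return callback.result()
```

With this layout, llc.py depends on nothing in sampler.py, and sampler.py only receives callback objects at call time, so neither module needs a top-level import of the other.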