timaeus-research / devinterp

Tools for studying developmental interpretability in neural networks.

Add gradient binning #57

Closed georgeyw closed 10 months ago

georgeyw commented 10 months ago

this PR felt very leetcode-flavored (which is to say that coming up with the idea was fun, and then most of the implementation was index chasing)

the problem here is that you want an online process for binning gradients, since you can't just save all of them due to memory usage, so what this does instead is:

a nice consequence of this is that the bins are chosen relative to all of the gradients from all model params, so you can tell whether a particular layer's gradients are big or small relative to other gradients in the model without having to explicitly compare them
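to make the "shared bins" point concrete, here is a minimal sketch of online binning against one set of bin edges used for every layer. this is not the code in this PR: the class name `OnlineGradientHistogram`, the fixed `lo`/`hi` range, and the clamping of out-of-range values into the edge bins are all assumptions made for illustration.

```python
from bisect import bisect_right
from collections import defaultdict


class OnlineGradientHistogram:
    """Hypothetical helper: per-layer counts over one shared set of bin edges."""

    def __init__(self, lo, hi, num_bins):
        if not (hi > lo and num_bins > 0):
            raise ValueError("need hi > lo and num_bins > 0")
        width = (hi - lo) / num_bins
        # Assumption: the range is fixed up front and shared by all layers.
        self.edges = [lo + i * width for i in range(num_bins + 1)]
        self.counts = defaultdict(lambda: [0] * num_bins)

    def update(self, layer, grads):
        """Fold one draw's gradient values for `layer` into its counts."""
        counts = self.counts[layer]
        last = len(counts) - 1
        for g in grads:
            # bisect_right returns the index of the first edge greater than g,
            # so subtracting 1 gives the bin whose left edge is <= g.
            idx = bisect_right(self.edges, g) - 1
            # Clamp values outside [lo, hi] into the first or last bin.
            idx = min(max(idx, 0), last)
            counts[idx] += 1


hist = OnlineGradientHistogram(lo=-1.0, hi=1.0, num_bins=4)
hist.update("layer_a", [-0.9, 0.1, 0.6])
hist.update("layer_a", [5.0])        # out of range, lands in the last bin
hist.update("layer_b", [-3.0, 0.0])  # same edges, separate counts
```

only the counts survive each update, so memory stays constant in the number of draws, and because every layer indexes into the same `edges`, counts from different layers can be compared directly.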

a downside might be that if you have a lot of draws and a lot of layers, then a single outlier gradient in some other layer can blow out the y-axis for the layer you actually care about. this isn't super easy to deal with, since you'd basically have to re-run the sampling process: building the histograms online throws away the info you'd need to rebuild them with different bins

maybe one idea would be to scale the magnitudes of gradients away from the origin logarithmically? I didn't think about that while implementing it. it seems doable but maybe complicated. maybe wait and see if this is a problem and patch it later?
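one possible reading of the log-scaling idea is a symmetric log transform applied before binning, sketched below. nothing like this is implemented in the PR; the function name `symlog` and the `scale` parameter are hypothetical.

```python
import math


def symlog(g, scale=1.0):
    """Sign-preserving log compression: sign(g) * log(1 + |g| / scale).

    Roughly linear for |g| much smaller than `scale`, logarithmic beyond it.
    `scale` is a hypothetical knob for where that transition happens.
    """
    return math.copysign(math.log1p(abs(g) / scale), g)


# A huge outlier is pulled in close to the bulk of the values.
compressed = [symlog(g) for g in (-1e6, -0.01, 0.0, 0.01, 1e6)]
```

binning `symlog(g)` instead of `g` would make the bins uniform in log-magnitude, so a single large gradient stretches the axis much less. the tradeoff is that the bin edges then need to be mapped back through the inverse transform to label the axis in original units.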

here is another test graph, in addition to the new example in the diagnostics notebook:

image

georgeyw commented 10 months ago

Updated with Stan's feedback (thanks!). Happy to merge once there's an approval stamp on it.