this PR felt very leetcode-flavored (which is to say that coming up with the idea was fun, and then most of the implementation was index chasing)
the problem here is that you want an online process for binning gradients, since you can't save all of them without blowing up memory. so what this does instead is (sketched in code after the list):
- use the first draw to get an initial estimate of the number and size of bins needed
- on subsequent draws, if values fall outside the range of the existing bins, add more bins on the necessary side
- if too many bins accumulate, merge adjacent bins until the count falls between `min_bins` and `min_bins * 2`
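here's a minimal sketch of the scheme (not the code in this PR — `min_bins` is the parameter mentioned above, but `OnlineHistogram` and every other name are made up for illustration):

```python
import numpy as np

class OnlineHistogram:
    """Sketch of an online histogram with fixed-width bins that can
    grow at the edges and merge when there are too many."""

    def __init__(self, first_draw, min_bins=20):
        lo, hi = float(np.min(first_draw)), float(np.max(first_draw))
        # first draw fixes the initial range and bin width
        self.width = (hi - lo) / min_bins if hi > lo else 1.0
        self.lo = lo
        self.min_bins = min_bins
        self.counts = np.zeros(min_bins, dtype=np.int64)
        self.add(first_draw)

    def add(self, values):
        values = np.asarray(values, dtype=float)
        lo, hi = float(values.min()), float(values.max())
        # grow the histogram on whichever side new values fall outside of
        below = int(np.ceil(max(self.lo - lo, 0.0) / self.width))
        top = self.lo + len(self.counts) * self.width
        above = int(np.ceil(max(hi - top, 0.0) / self.width))
        if below or above:
            self.counts = np.concatenate([
                np.zeros(below, dtype=np.int64),
                self.counts,
                np.zeros(above, dtype=np.int64),
            ])
            self.lo -= below * self.width
        idx = ((values - self.lo) / self.width).astype(np.int64)
        idx = np.clip(idx, 0, len(self.counts) - 1)
        np.add.at(self.counts, idx, 1)
        # merge adjacent bin pairs (doubling the width) until the bin
        # count is back between min_bins and min_bins * 2
        while len(self.counts) >= 2 * self.min_bins:
            if len(self.counts) % 2:
                self.counts = np.append(self.counts, 0)
            self.counts = self.counts.reshape(-1, 2).sum(axis=1)
            self.width *= 2
```

usage would be `hist = OnlineHistogram(first_draw)` followed by `hist.add(draw)` once per subsequent draw, so memory stays proportional to the number of bins no matter how many draws you see.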
a nice consequence of this is that the bins are shared across the gradients of all model params, so you can tell at a glance whether a particular layer's gradients are big or small relative to the rest of the model without having to explicitly compare them
a downside is that with a lot of draws and a lot of layers, a single extreme gradient in some other layer can blow out the y-axis for everyone else. that isn't easy to deal with after the fact, because building the histograms online throws away the information you'd need to reconstruct them with different bins; you'd basically have to re-run the sampling process
one idea would be to scale gradient magnitudes logarithmically away from the origin before binning. I didn't consider that while implementing; it seems doable but maybe complicated. probably best to wait and see whether this is a problem in practice and patch it later
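concretely, that could look something like a symlog transform applied before binning (purely a sketch of the idea, not code in this PR; `linthresh` is a hypothetical knob):

```python
import numpy as np

def symlog(x, linthresh=1e-8):
    # hypothetical pre-binning transform: sign-preserving, roughly linear
    # inside [-linthresh, linthresh], logarithmic further from the origin
    return np.sign(x) * np.log1p(np.abs(x) / linthresh)
```

the histogram would then bin `symlog(grad)` instead of `grad`, so one huge gradient only stretches the axis by a log factor rather than linearly.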
here is another test graph, in addition to the new example in the diagnostics notebook: