ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Fix/negative variance edge ablation #323

Closed: stefan-apollo closed this 7 months ago

stefan-apollo commented 8 months ago

If we perform edge ablations, the variance node can take a negative value. We expect such strong ablations to destroy performance anyway. We could special-case a negative variance and return a bad loss, but we think it's simpler to clamp negative variances to zero. This should make the layer norm scale (1/sqrt(var + eps) in a standard layer norm) very large, and thus produce a bad loss if and only if the layer norm scale was important. Concretely:

    var = torch.relu(var)
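A minimal sketch of why the clamp behaves as described, written in plain Python instead of torch so it runs standalone. It assumes a standard layer norm that divides by sqrt(var + eps) with eps = 1e-5, and treats the mean and variance as values handed in by upstream nodes. The function name `layer_norm_from_nodes` and the example numbers are hypothetical and not part of rib; `max(var, 0.0)` is the scalar analogue of `torch.relu(var)`.

```python
import math
from typing import List


def layer_norm_from_nodes(x: List[float], mean: float, var: float, eps: float = 1e-5) -> List[float]:
    """Normalize x with a mean and variance supplied by upstream nodes.

    Hypothetical helper: var may be negative after edge ablation, so it is
    clamped at zero first (the scalar analogue of torch.relu(var)).
    """
    var = max(var, 0.0)  # negative variance -> 0
    scale = 1.0 / math.sqrt(var + eps)  # becomes 1/sqrt(eps), i.e. very large, when var == 0
    return [(xi - mean) * scale for xi in x]


# Hypothetical ablated variance of -0.6: the clamp turns it into 0.
out = layer_norm_from_nodes([1.0, 2.0, 3.0], mean=2.0, var=-0.6)
print(out)  # middle entry is 0.0; outer entries are about -316.2 and +316.2
```

Without the clamp, `math.sqrt` would raise on the negative argument here. In torch, `torch.sqrt` of a negative value returns NaN rather than raising, which would silently poison the loss instead of merely making it bad.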

nix-apollo commented 7 months ago

Oops. Changing the base branch doesn't work, since many commits on this branch were already squash-merged into main. We'll probably need to cherry-pick just a85db04 into a new PR.

stefan-apollo commented 7 months ago

I could raise a warning, but I'm worried that one or more warnings per model run will be very annoying and clutter our terminal. I see the point, though. Happy either way, I guess?

nix-apollo commented 7 months ago

By default, Python shows a given warning only the first time it is issued from a particular line and filters out the repeats. More info here. I'm not sure how it works for logger.warn.
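A small check of this deduplication, using only the standard `warnings` module; `count_shown_warnings` is a hypothetical helper and nothing here comes from rib. Under the "default" filter action, Python stores each (message text, category, line number) in the calling module's warning registry and suppresses later matches. A constant message from one line therefore appears once, while a message that embeds a changing value (for example the offending variance) appears every time.

```python
import warnings
from typing import Iterable


def count_shown_warnings(messages: Iterable[str]) -> int:
    """Issue one UserWarning per message, all from the same source line,
    and return how many of them actually got through the filters."""
    with warnings.catch_warnings(record=True) as caught:
        # "default": show the first occurrence for each
        # (message text, category, line number); suppress repeats.
        warnings.simplefilter("default")
        for msg in messages:
            warnings.warn(msg, UserWarning)  # same line every iteration
    return len(caught)


# Constant text: only the first of five warnings is shown.
same_text = count_shown_warnings(["Negative variance clamped to 0."] * 5)
# Text that embeds a changing value: every warning is shown.
varying_text = count_shown_warnings(
    [f"Negative variance {v} clamped to 0." for v in (-0.1, -0.2, -0.3)]
)
print(same_text, varying_text)  # 1 3
```

On the logging side: `logger.warn` is a deprecated alias of `logger.warning`, and the `logging` module itself does not deduplicate records. One option is to keep using `warnings.warn` and call `logging.captureWarnings(True)`, which forwards warnings that survive the filters to the `py.warnings` logger.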

stefan-apollo commented 7 months ago

I get the warning a few times, which seems good enough.

Will implement this in a bit and cherry-pick 63f41508964ddd79deb2d18620a3b20ef530e676.
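One way the agreed behaviour could look, as a plain-Python sketch rather than the code from this PR: clamp negative variances to zero and emit a `RuntimeWarning` with a fixed message, so the default filter shows it only once per source line. The function name `clamp_variance`, the choice of `RuntimeWarning`, and the message text are all hypothetical.

```python
import warnings
from typing import List


def clamp_variance(var: List[float]) -> List[float]:
    """Hypothetical helper: clamp negative variances to zero (the list
    analogue of torch.relu(var)), warning if any were negative."""
    if any(v < 0.0 for v in var):
        # Fixed message text: the warning registry key is
        # (text, category, line number), so a constant message lets the
        # "default" filter suppress repeats issued from this line.
        warnings.warn(
            "Negative variance from edge ablation; clamping to 0.",
            RuntimeWarning,
        )
    return [max(v, 0.0) for v in var]


print(clamp_variance([0.5, -0.2, 1.0]))  # [0.5, 0.0, 1.0]
```

Passing `stacklevel=2` would attribute the warning to the caller's line instead, so each distinct call site would then produce its own first warning.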

stefan-apollo commented 7 months ago

Closed in favor of #346.