ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Fix/negative variance edge ablation [Replacement for #323] #346

Closed: stefan-apollo closed this 7 months ago

stefan-apollo commented 7 months ago

If we perform edge ablations, we can produce a negative value in the variance node. We expect such strong ablations to destroy performance. We could add a special case that returns a bad loss whenever a variance is negative, but we think it is simpler to set negative variances to zero. Since the layer norm scale is 1/sqrt(var + eps), a zero variance blows the scale up to 1/sqrt(eps), which produces a bad loss if and only if the layer norm scale was important. In code, this is `var = torch.relu(var)`.
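As a rough illustration of why clamping at zero blows up the scale, here is a minimal pure-Python sketch. It assumes the standard layer norm form (x - mean) / sqrt(var + eps) with a hypothetical eps of 1e-5, and uses `max(var, 0.0)` as the scalar stand-in for `torch.relu(var)`. The function names are hypothetical and are not part of rib.

```python
import math

EPS = 1e-5  # hypothetical epsilon, matching the common LayerNorm default


def layer_norm_scale(var: float, eps: float = EPS) -> float:
    """Return the layer norm scale 1 / sqrt(var + eps) after clamping var at zero.

    max(var, 0.0) is the scalar equivalent of torch.relu(var).
    """
    var = max(var, 0.0)
    return 1.0 / math.sqrt(var + eps)


def layer_norm(x: list[float], var: float, eps: float = EPS) -> list[float]:
    """Normalise x using an externally supplied (possibly ablated) variance."""
    mean = sum(x) / len(x)
    scale = layer_norm_scale(var, eps)
    return [(xi - mean) * scale for xi in x]


x = [1.0, 2.0, 3.0, 4.0]
true_var = sum((xi - 2.5) ** 2 for xi in x) / len(x)  # population variance: 1.25
print(layer_norm_scale(true_var))  # roughly 0.894, a well-behaved scale
print(layer_norm_scale(-0.3))      # roughly 316.2, i.e. 1/sqrt(eps): the scale blows up
```

Without the clamp, `math.sqrt(-0.3 + 1e-5)` would raise `ValueError` (and `torch.sqrt` would return `nan`), so clamping turns an invalid computation into an extreme but finite one whose effect on the loss depends on how much the layer norm scale mattered.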

We also log a warning whenever this happens.
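A hypothetical helper combining the clamp with the warning might look like the sketch below. It works on a plain list to stay self-contained and uses Python's standard `logging` module; the function name and message text are assumptions, not the PR's actual code, which applies `torch.relu` to a tensor.

```python
import logging

logger = logging.getLogger(__name__)


def clamp_variance(var: list[float]) -> list[float]:
    """Set negative variances to zero, logging a warning if any were found.

    Hypothetical helper: a list-based stand-in for var = torch.relu(var).
    """
    n_negative = sum(1 for v in var if v < 0)
    if n_negative:
        # Warn once per call rather than once per element.
        logger.warning(
            "Found %d negative variance value(s) after edge ablation; clamping to zero.",
            n_negative,
        )
    return [max(v, 0.0) for v in var]
```

Logging once per call with a count keeps the log readable when an ablation produces many negative entries at once.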

This PR replaces #323, because opening a new PR was easier than merging main into #323.

stefan-apollo commented 7 months ago

Confirmed that the Files changed tab is identical to that of #323.