When we perform edge ablations, we can produce a negative value in the variance node. We
expect such strong ablations to destroy performance. While we could implement a special
case that returns a bad loss whenever we find a negative variance, it is simpler to clamp
negative variances to zero -- this should suitably blow up the layer norm scale, and thus
produce a bad loss if and only if the layer norm scale was important.
```python
var = torch.relu(var)
```
We also log a warning whenever this happens.
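A minimal sketch of how this could look in the layer norm path (the helper name and call site are hypothetical, not the exact code in this PR; it assumes `var` arrives precomputed from the ablated graph rather than being recomputed from `x`):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def clamp_negative_variance(var: torch.Tensor) -> torch.Tensor:
    """Clamp negative variances (possible under edge ablation) to zero."""
    if (var < 0).any():
        # Clamping to zero makes 1 / sqrt(var + eps) very large, so the loss
        # blows up exactly when the layer norm scale mattered.
        logger.warning("Negative variance after edge ablation; clamping to zero.")
        var = torch.relu(var)
    return var


def apply_layer_norm_scale(
    x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    """Normalize x using a variance that may have gone negative."""
    var = clamp_negative_variance(var)
    return (x - mean) / torch.sqrt(var + eps)
```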
This PR replaces #323, since opening a fresh PR was easier than merging main.