ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Normalize the edges when calculating #187

Closed: danbraunai-apollo closed this issue 10 months ago

danbraunai-apollo commented 10 months ago

We normalize the gradients when calculating the Cs, but don't yet do so for the edges.

This is the current code:

    from jaxtyping import Float
    from torch import Tensor

    # Calculate the square and sum over the pos dimension if it exists.
    f_out_hat_norm: Float[Tensor, "... out_hidden_combined_trunc"] = f_out_hat**2
    if has_pos:
        # f_out_hat is shape (batch, pos, hidden)
        assert f_out_hat.dim() == 3, f"f_out_hat should have 3 dims, got {f_out_hat.dim()}"
        f_out_hat_norm = f_out_hat_norm.sum(dim=1)

    # Sum over the batch dimension
    f_out_hat_norm = f_out_hat_norm.sum(dim=0)

Instead, I think we should take the mean over these dimensions (pos, if present, and batch).
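A minimal sketch of the proposed change, assuming `f_out_hat` has shape `(batch, pos, hidden)` when `has_pos` is true and `(batch, hidden)` otherwise, as in the snippet above. The helper name `f_out_hat_mean_sq` is hypothetical and not taken from the rib codebase. The only change is swapping each `.sum` for `.mean`, so the result equals the summed version divided by the number of averaged elements (batch × pos).

```python
import torch
from torch import Tensor


def f_out_hat_mean_sq(f_out_hat: Tensor, has_pos: bool) -> Tensor:
    """Mean of f_out_hat**2 over batch (and pos, if present). Hypothetical helper."""
    f_out_hat_sq = f_out_hat**2
    if has_pos:
        # f_out_hat is shape (batch, pos, hidden)
        assert f_out_hat.dim() == 3, f"f_out_hat should have 3 dims, got {f_out_hat.dim()}"
        # Average over the pos dimension instead of summing
        f_out_hat_sq = f_out_hat_sq.mean(dim=1)
    # Average over the batch dimension instead of summing
    return f_out_hat_sq.mean(dim=0)


# Usage: compare against the current summed version on random data
f = torch.randn(8, 5, 4)  # (batch=8, pos=5, hidden=4)
summed = (f**2).sum(dim=1).sum(dim=0)
meaned = f_out_hat_mean_sq(f, has_pos=True)
# The two differ only by the constant factor batch * pos = 40
print(torch.allclose(summed, meaned * 40))
```

Because the factor is a constant for a fixed batch and sequence length, the change rescales the edges rather than reweighting them relative to each other within a run.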

First, compare the sizes before and after this adjustment to see whether it has a stabilising effect.
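A possible starting point for that comparison, sketched on synthetic data only: random Gaussian activations stand in for the real `f_out_hat`, and the helper `norm_scales` and its default sizes are hypothetical. It shows one property that follows directly from the arithmetic: the summed norm grows with batch size (and sequence length), while the averaged norm stays roughly constant. Whether that translates into more stable edges on real models is exactly what the investigation would need to check.

```python
import torch


def norm_scales(batch: int, pos: int = 16, hidden: int = 32, seed: int = 0) -> tuple[float, float]:
    """Return (summed, averaged) squared norms, each averaged over hidden, for synthetic data."""
    gen = torch.Generator().manual_seed(seed)
    # Synthetic stand-in for f_out_hat with unit-variance entries
    f_out_hat = torch.randn(batch, pos, hidden, generator=gen)
    sq = f_out_hat**2
    summed = sq.sum(dim=(0, 1)).mean().item()  # current behaviour: sum over batch and pos
    averaged = sq.mean(dim=(0, 1)).mean().item()  # proposed behaviour: mean over batch and pos
    return summed, averaged


for batch in (8, 64, 512):
    summed, averaged = norm_scales(batch)
    print(f"batch={batch:4d}  summed={summed:10.1f}  averaged={averaged:.3f}")
```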