ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

PCA basis #249

Closed · nix-apollo closed this 8 months ago

nix-apollo commented 9 months ago

PCA

Description

This implements the analogous version of the "svd" basis, but with the activations centred first. This involves:

Related Issue

First half of #248

Motivation and Context

This is the first step in preparing to compute centred RIB activations. The next steps towards a "centred-rib" basis would build on this.
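For concreteness, here is a minimal sketch of the centring-then-rotating idea, assuming the activations have been stacked into an (n_samples, d_hidden) matrix. The function name and shapes are illustrative only, not the library's API:

```python
import torch

def pca_basis(acts: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative sketch (not the library's API): the "pca" basis is the
    "svd" basis computed on activations that have been centred first.

    acts: (n_samples, d_hidden) activations collected over the dataset.
    Returns the per-dimension mean and a rotation whose columns are ordered
    by decreasing variance.
    """
    mean = acts.mean(dim=0)                      # dataset mean of each dimension
    centred = acts - mean                        # the only difference vs the "svd" basis
    gram = centred.T @ centred / acts.shape[0]   # covariance of the centred activations
    _, rotation = torch.linalg.eigh(gram)        # eigh returns ascending eigenvalues
    return mean, rotation.flip(-1)               # reorder columns to descending variance
```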

How Has This Been Tested?

Does this PR introduce a breaking change?

I don't think so.

nix-apollo commented 9 months ago

Reverting the commit, as it was intended to go on a different branch.

nix-apollo commented 9 months ago

Re: bias positions

Hmm, yes, I had underestimated the complexity here. The hook function itself can inspect the module it's attached to, and I think it's not terrible to special-case some modules in the hook function. I do think it's worth moving towards a world where every module we use has properties that reveal the input shapes, output shapes, and bias positions.

Unfortunately you'd have to pass this information to the next module because all of our stuff happens in pre_forward hooks.

Shouldn't we be able to determine the input bias positions for every module as well? Adjacent modules should agree about where the bias positions are. It's a bit redundant to have both but I think it's fine. If we only have one I'd probably go for the input bias positions since we operate in pre-hook land.
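As a rough illustration of what this looks like in pre-hook land, assuming a hypothetical in_bias_positions property on the module (not something our modules currently expose):

```python
import torch

def check_folded_bias_pre_hook(module: torch.nn.Module, args: tuple) -> None:
    """Hypothetical pre-forward hook: `in_bias_positions` is an assumed property,
    not something modules currently expose. The hook inspects the module it is
    attached to and verifies the folded-bias 1s sit where that module expects them."""
    (x,) = args  # assume a single-tensor input for the sketch
    bias_vals = x[..., module.in_bias_positions]
    assert torch.allclose(bias_vals, torch.ones_like(bias_vals)), "expected constant 1s"

# usage (hypothetical): handle = module.register_forward_pre_hook(check_folded_bias_pre_hook)
```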

The options that seem most appealing to me are:

  1. Add a SequentialComponent class which declares these properties, and add the logic to compute them to the various components (see the sketch after this list). I think this is the nicest solution, and it should be "not that bad" to implement, but it increases the scope of this PR and maybe we want to think through the design decisions more carefully. Would allow other cleanup in a future PR.
  2. Add some logic that special cases components in the hook. Mostly this involves:
    • asserting the module has folded_bias, which I think ensures it originally had a bias and expects a 1 in its input somewhere.
    • special-casing AttentionOut, which expects the bias in different positions.
    • looking over the other components to see if there are other special cases that need handling.
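To make option 1 concrete, here is a rough sketch of what such a class might look like; the names and the example component are made up, not a proposed final design:

```python
import torch

class SequentialComponent(torch.nn.Module):
    """Hypothetical base class for option 1: every component declares its
    bias positions (and could similarly declare input/output shapes), so the
    hook never needs to special-case module types."""

    @property
    def in_bias_positions(self) -> list[int]:
        """Indices of the constant-1 (folded-bias) dimensions in this module's input."""
        raise NotImplementedError

    @property
    def out_bias_positions(self) -> list[int]:
        """Indices of the constant-1 (folded-bias) dimensions in this module's output."""
        raise NotImplementedError


class FoldedBiasLinear(SequentialComponent):
    """Illustrative component: a linear layer with its bias folded into the
    weights, so the last input dimension is expected to be a constant 1."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(d_in + 1, d_out))

    @property
    def in_bias_positions(self) -> list[int]:
        return [self.weight.shape[0] - 1]

    @property
    def out_bias_positions(self) -> list[int]:
        return []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight
```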

I'm not sure which is nicer. I prefer 1 long term but also don't really want to increase the scope of this PR more than needed.

Regardless, this should have some better tests. For instance, for every module type in Pythia, asserting that means[bias_positions] == 1 and means[non_bias_positions] != 1.
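Something along these lines, per module type (the helper name and signature are made up):

```python
import torch

def assert_means_respect_bias_positions(means: torch.Tensor, bias_positions: list[int]) -> None:
    """Hypothetical test helper: at folded-bias positions the dataset mean must be
    exactly 1 (the input there is a constant 1), and at all other positions it
    should not be."""
    is_bias = torch.zeros(means.shape[0], dtype=torch.bool)
    is_bias[bias_positions] = True
    assert torch.allclose(means[is_bias], torch.ones_like(means[is_bias]))
    assert not torch.any(torch.isclose(means[~is_bias], torch.ones_like(means[~is_bias])))
```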

nix-apollo commented 9 months ago

I'll also note that this TODO in the PR description hasn't yet been implemented:

Computing the mean activations of the dataset. We have this mean ignore the 1s appended to the activations when folding in bias [TODO: fix this]

This seems reasonably important, especially for toy models (the logic shouldn't be that bad?)

Thanks for flagging this. I decided to take a different approach: return the mean for all positions but track the bias positions separately. Then, when shifting by the bias in collect_gram_matrices and create_shift_matrix, we ignore these positions. I think this is more intuitive (the means returned are actually the means), but it does require passing an extra argument around.
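Roughly, the shift now looks like this (a standalone sketch with made-up names; the real change lives in collect_gram_matrices and create_shift_matrix):

```python
import torch

def shift_by_mean(acts: torch.Tensor, means: torch.Tensor, bias_positions: list[int]) -> torch.Tensor:
    """Illustrative sketch: `means` holds the true mean of every position, and the
    bias positions are only ignored at the point where we shift, so the stored
    means stay honest."""
    shift = means.clone()
    shift[bias_positions] = 0.0   # leave the folded-bias 1s untouched
    return acts - shift
```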

If the modules track their own bias positions we will not need to pass the argument around, at least.

danbraunai-apollo commented 9 months ago

Shouldn't we be able to determine the input bias positions for every module as well? Adjacent modules should agree about where the bias positions are.

This is true. One thing we have to handle, though, is that if we want a property for the input bias positions, modules are also going to need to know whether their adjacent modules have a folded bias or not. We could just assume that there is always a folded bias anytime we run RIB, or assume that the adjacent module has folded_bias = True if the current module does (that's a bit gross though).