ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

[WIP] Support jacobian basis (formerly November B) #244

Closed: stefan-apollo closed this 7 months ago

stefan-apollo commented 9 months ago

Note: Forked from feature/centred-rib!

Todo:

Description

  1. Implements the "Jacobian basis", named this way because g_ij is a Jacobian and M = g.T @ g is the matrix that gets eigendecomposed.
  2. Rename basis calculation functions to calcbasis...
  3. Improve tqdm
  4. Replaces our .backward + .grad + .zero_grad_ chain with a single autograd.grad call in the other basis formulas. I confirmed this gives identical results for the Jacobian basis, and the results for the others look fine too; they should be identical anyway.
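A minimal PyTorch sketch of items 1 and 4, under stated assumptions: a toy function `f` (hypothetical, standing in for whatever map the repo differentiates) provides the Jacobian g_ij, M is accumulated as g.T @ g over a few random inputs, and the Jacobian is built both with torch.autograd.grad and with the older .backward() / .grad / zeroing pattern to show that the two agree. This is not the repo's implementation; all function names here are illustrative.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy map standing in for the function whose Jacobian is g_ij.
W = torch.tensor([[1.0, 2.0, 0.0],
                  [0.0, 1.0, -1.0]], dtype=torch.float64)

def f(x):
    return torch.tanh(W @ x)

def jacobian_via_autograd_grad(x):
    """g[i, j] = d f_i / d x_j, one output row per torch.autograd.grad call."""
    x = x.detach().requires_grad_(True)
    out = f(x)
    rows = []
    for i in range(out.shape[0]):
        (row,) = torch.autograd.grad(out[i], x, retain_graph=True)
        rows.append(row)
    return torch.stack(rows)

def jacobian_via_backward(x):
    """Same Jacobian via the older .backward() / .grad / zeroing pattern."""
    x = x.detach().requires_grad_(True)
    out = f(x)
    rows = []
    for i in range(out.shape[0]):
        out[i].backward(retain_graph=True)
        rows.append(x.grad.clone())
        x.grad.zero_()  # gradients accumulate, so reset before the next row
    return torch.stack(rows)

inputs = torch.randn(8, 3, dtype=torch.float64)

# Accumulate M = sum over inputs of g.T @ g (a 3x3 symmetric PSD matrix).
M = torch.zeros(3, 3, dtype=torch.float64)
for x in inputs:
    g = jacobian_via_autograd_grad(x)
    M += g.T @ g

# eigh is the appropriate routine for a symmetric matrix; eigenvalues are ascending.
eigvals, eigvecs = torch.linalg.eigh(M)
```

Because every row of g is a rescaled row of W, M here has rank at most 2, so one eigenvalue comes out at round-off level; that is the kind of direction a truncation threshold is meant to drop.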
stefan-apollo commented 9 months ago

This basis, unlike the previous bases, is sensitive to changing the truncation threshold from 1e-6 to 1e-15. This may or may not be a bug. [image]
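A minimal sketch of how the truncation threshold interacts with the eigenvalue spectrum, assuming (for illustration only) that directions whose eigenvalue is at or below `truncation_threshold` are dropped. `truncate_basis` is a hypothetical helper, not the repo's function. Lowering the threshold from 1e-6 to 1e-15 keeps extra directions with tiny eigenvalues, which is one way a basis could become sensitive to this setting.

```python
import torch

def truncate_basis(eigvals, eigvecs, threshold):
    """Keep only eigenvector columns whose eigenvalue exceeds `threshold`
    (hypothetical helper; the repo's exact rule may differ)."""
    keep = eigvals > threshold
    return eigvals[keep], eigvecs[:, keep]

# Illustrative spectrum spanning many orders of magnitude.
eigvals = torch.tensor([1.0, 1e-3, 1e-9, 1e-20], dtype=torch.float64)
eigvecs = torch.eye(4, dtype=torch.float64)  # columns are eigenvectors

for thr in (1e-6, 1e-15):
    kept_vals, kept_vecs = truncate_basis(eigvals, eigvecs, thr)
    print(f"threshold={thr:g}: kept {kept_vecs.shape[1]} directions")
```

With the looser threshold only the two largest directions survive; with 1e-15 the 1e-9 direction is kept as well, so anything downstream that depends on small eigenvalues can change.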

stefan-apollo commented 8 months ago

With the old Lambdas we get results like this, which look similar to those in #276: [image]

stefan-apollo commented 8 months ago

Do the new Lambdas solve this to some extent?

We currently have to set Lambda[0] manually for the bias. The value of Lambda[0] only matters for (a) numerics and (b) the strength of bias node connections.

We will fix the sorting in a later commit, without needing the "set it to a huge value" workaround.
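One way to get this ordering without the "set it to a huge value" workaround is to sort only the non-constant directions and keep the constant (bias) direction pinned at index 0. This is a sketch assuming the constant direction is the first column of the basis; `sort_keep_bias_first` is a hypothetical name, and this is not necessarily how the later commit fixes it.

```python
import torch

def sort_keep_bias_first(lambdas, basis):
    """Order directions by Lambda (descending) while keeping index 0, assumed
    to be the constant/bias direction, first regardless of its Lambda value."""
    rest = torch.argsort(lambdas[1:], descending=True) + 1
    order = torch.cat([torch.zeros(1, dtype=torch.long), rest])
    return lambdas[order], basis[:, order]

basis = torch.eye(4, dtype=torch.float64)  # columns are basis directions

for lambda0 in (1.0, 1e5, 1e9):
    lambdas = torch.tensor([lambda0, 5.0, 0.2, 3.0], dtype=torch.float64)
    naive_order = torch.argsort(lambdas, descending=True)  # bias can land anywhere
    sorted_lambdas, sorted_basis = sort_keep_bias_first(lambdas, basis)
    print(lambda0, naive_order.tolist(), sorted_lambdas.tolist())
```

With Lambda[0] = 1, a plain descending sort places the constant direction in the middle (order [1, 3, 0, 2]), while the pinned sort keeps it first for every Lambda[0] value tried.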

Open question: Were the stray edges numerical all along?

[images: modular_arithmetic_jacobian_new_lambdas_1e9_rib_graph, modular_arithmetic_jacobian_new_lambdas_1e5_rib_graph]

stefan-apollo commented 8 months ago

With the fix we can set Lambda[0] to 1 without it being sorted somewhere in the middle!

[image] [image]

Note: These images hide the first (constant) dimension; otherwise it would sometimes drown out everything, depending on the value of Lambda[0].

stefan-apollo commented 8 months ago

The big blob in the "large" (Lambda0=1e9) case went away. [image]

But Lambda0=1e5 and Lambda0=1e0 still differ! There is a smaller, barely visible blob, which you can see in the log plot here: [image]

Will make an issue for this to investigate later. It is unclear whether the "stray edges" issue was caused by incorrect Lambdas or by this numerical phenomenon.

stefan-apollo commented 8 months ago

Note: I would like to include Lucius's "correct" Lambda[0] value, if it's not too much work, to get rid of inconsistent lines like these: [image]

That said, we almost always plot centered builds without showing edges to the constant node, so it's not a big deal if not.

stefan-apollo commented 8 months ago

Note: Please convert this Google doc into a markdown file in the repo when finishing this PR.

stefan-apollo commented 7 months ago

> Note: I would like to include Lucius's "correct" Lambda[0] value if it's not too much work, to get rid of inconsistent lines like these

Done; the graph looks fine.

[image]

stefan-apollo commented 7 months ago

Implemented this for non-transformer models. Here's the MLP test:

(1-0)*alpha basis: [image]

Jacobian basis: [image]

stefan-apollo commented 7 months ago

I will now run the following tests on modular addition:

v0

basis_formula: (1-alpha)^2
edge_formula: functional

v1

basis_formula: (1-0)*alpha
edge_formula: squared

v2

basis_formula: jacobian
edge_formula: squared

with this config:

```yaml
exp_name: modular_arithmetic_jacobian_v0
seed: 0
tlens_pretrained: null
tlens_model_path: experiments/train_modular_arithmetic/sample_checkpoints/lr-0.001_bs-10000_norm-None_2023-11-28_16-07-19/model_epoch_60000.pt
interaction_matrices_path: null
dataset:
  source: custom
  name: modular_arithmetic
  return_set: train
node_layers:
  - ln1.0
  - ln2.0
  - mlp_out.0
  - unembed
  - output
rotate_final_node_layer: false
batch_size: 999999
edge_batch_size: 999999
truncation_threshold: 1e-15
last_pos_module_type: add_resid1  # module type in which to only output the last position index
n_intervals: 5
dtype: float64
calculate_edges: true
eval_type: accuracy
center: True
```
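The three runs share the config above and differ only in the basis and edge formulas. A small stdlib-only sketch of generating the variants, assuming `basis_formula` and `edge_formula` are top-level config keys (the config above does not show them) and that each run gets its own `exp_name` following a made-up pattern; only a subset of the keys is reproduced.

```python
import copy

# Subset of the shared config above; the remaining keys are identical across runs.
base_config = {
    "exp_name": "modular_arithmetic_jacobian_v0",
    "seed": 0,
    "node_layers": ["ln1.0", "ln2.0", "mlp_out.0", "unembed", "output"],
    "truncation_threshold": 1e-15,
    "dtype": "float64",
    "calculate_edges": True,
    "center": True,
}

# The three variants listed above. Treating basis_formula / edge_formula as
# top-level config keys is an assumption of this sketch.
variants = {
    "v0": {"basis_formula": "(1-alpha)^2", "edge_formula": "functional"},
    "v1": {"basis_formula": "(1-0)*alpha", "edge_formula": "squared"},
    "v2": {"basis_formula": "jacobian", "edge_formula": "squared"},
}

def make_configs(base, variants):
    """Return one independent config dict per variant (hypothetical helper)."""
    configs = {}
    for name, overrides in variants.items():
        cfg = copy.deepcopy(base)  # deep copy so lists are not shared between runs
        cfg.update(overrides)
        cfg["exp_name"] = f"modular_arithmetic_jacobian_{name}"  # illustrative naming
        configs[name] = cfg
    return configs

configs = make_configs(base_config, variants)
```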
stefan-apollo commented 7 months ago

Ablations vs. SVD and PCA: [image: abla]

Graphs: [images: modular_arithmetic_v0_rib_graph, modular_arithmetic_v1_rib_graph, modular_arithmetic_v2_rib_graph]
SVD: [image: modular_arithmetic_svd_rib_graph]
PCA: [image: modular_arithmetic_pca_rib_graph]

stefan-apollo commented 7 months ago

CIFAR MLP ablations: [image]

Graphs: [images: cifar-mlp-v1_rib_graph, cifar-mlp-v2_rib_graph]

stefan-apollo commented 7 months ago

Replaced by #294 due to a refactor.