ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

[RFR] Bisect ablation scheduler #303

Closed: stefan-apollo closed this 6 months ago

stefan-apollo commented 6 months ago

Refactor ablation schedules & implement bisect schedule

Description

Motivation and Context

The main metrics for ablation tests are questions like "how many edges do we need to retain 99.9% performance?". Rather than zooming in on coarsely sampled ablation curves, we should bisect for that number directly.
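A minimal sketch of such a bisect schedule. It assumes the score is (roughly) non-decreasing in the number of kept nodes or edges. min_count_for_target and score_fn are hypothetical names, not the rib API. The search brackets the answer between a count assumed to fail (starting at 0) and a count known to pass (starting at the full count), then halves the gap until the two are adjacent, so it needs only a logarithmic number of score evaluations. With the rounded-up midpoint used here, it visits exactly the counts listed for ln1.0 in the node results posted further down (65, 33, 17, 9, 13, 11, 10).

```python
from typing import Callable, Dict, Tuple


def min_count_for_target(
    score_fn: Callable[[int], float], n_max: int, target: float
) -> Tuple[int, Dict[int, float]]:
    """Smallest count n in [1, n_max] with score_fn(n) >= target, found by bisection.

    Hypothetical sketch: assumes score_fn is (roughly) non-decreasing in n and that
    keeping 0 elements misses the target. Returns the threshold and every score
    evaluated along the way.
    """
    scores = {n_max: score_fn(n_max)}
    if scores[n_max] < target:
        raise ValueError("target not reached even with all elements kept")
    lo, hi = 0, n_max  # lo: assumed to fail; hi: known to pass
    while hi - lo > 1:
        mid = (lo + hi + 1) // 2  # linear midpoint, rounded up
        scores[mid] = score_fn(mid)
        if scores[mid] >= target:
            hi = mid  # mid passes: threshold is at or below mid
        else:
            lo = mid  # mid fails: threshold is above mid
    return hi, scores


if __name__ == "__main__":
    # Toy score that jumps to 1.0 once 37 or more elements are kept.
    threshold, evaluated = min_count_for_target(lambda n: float(n >= 37), 100, 0.95)
    print(threshold, sorted(evaluated))  # 37 [25, 32, 35, 36, 37, 38, 50, 100]
```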

stefan-apollo commented 6 months ago

Example run on modular addition: in every layer we can see the single-node step where accuracy drops from above 0.95 to below 0.95 (marked with <--).


    "results": {
        "ln1.0": {
            "65": 1.0,
            "33": 1.0,
            "17": 0.9992167101827676,
            "13": 0.9892950391644909,
            "11": 0.954046997389034, <-- 
            "10": 0.9409921671018276,
            "9": 0.8861618798955614
        },
        "ln2.0": {
            "65": 1.0,
            "33": 1.0,
            "17": 1.0,
            "15": 0.9942558746736292,
            "14": 0.9684073107049609, <-- 
            "13": 0.9417754569190601,
            "9": 0.6966057441253264
        },
        "mlp_out.0": {
            "321": 1.0,
            "161": 1.0,
            "81": 1.0,
            "41": 1.0,
            "21": 1.0,
            "11": 1.0,
            "6": 0.9798955613577024, <-- 
            "5": 0.6652741514360313,
            "3": 0.16605744125326372
        },
        "unembed": {
            "65": 1.0,
            "33": 1.0,
            "17": 1.0,
            "9": 1.0,
            "7": 0.9911227154046998,
            "6": 0.9798955613577024, <-- 
            "5": 0.6652741514360313
        }
    },
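To read the marked step off a saved results file, one layer's dict (keys are counts stored as strings, values are scores) can be scanned for the smallest count that reaches the target and the largest count that misses it. A minimal sketch; threshold_step is a hypothetical helper, not part of rib, and it assumes the results were loaded with the standard json module:

```python
import json
from typing import Dict, Optional, Tuple


def threshold_step(layer_results: Dict[str, float], target: float) -> Tuple[int, Optional[int]]:
    """Return (smallest passing count, largest failing count) for one layer.

    Keys are counts stored as strings (JSON object keys), values are scores.
    The largest failing count is None if every evaluated count passed.
    """
    counts = [(int(k), score) for k, score in layer_results.items()]
    passing = [n for n, score in counts if score >= target]
    failing = [n for n, score in counts if score < target]
    if not passing:
        raise ValueError("no evaluated count reaches the target")
    return min(passing), (max(failing) if failing else None)


if __name__ == "__main__":
    ln2 = json.loads(
        '{"17": 1.0, "15": 0.9942558746736292, "14": 0.9684073107049609,'
        ' "13": 0.9417754569190601, "9": 0.6966057441253264}'
    )
    print(threshold_step(ln2, 0.95))  # (14, 13)
```

On the four layers above this returns (11, 10), (14, 13), (6, 5) and (6, 5): a one-node step each time.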

And for edges, where sweeping every possible edge count would otherwise be very expensive:


    "results": {
        "ln1.0": {
            "8321": 0.9953002610966057,
            "4161": 0.9953002610966057,
            "2081": 0.9953002610966057,
            "1041": 0.9953002610966057,
            "521": 0.9953002610966057,
            "261": 0.995822454308094,
            "131": 0.9642297650130548,
            "123": 0.9624020887728459,
            "119": 0.9660574412532638,
            "117": 0.9639686684073107, <--
            "116": 0.9498694516971279,
            "115": 0.9488250652741514,
            "99": 0.8960835509138381,
            "66": 0.6657963446475196
        },
        "ln2.0": {
            "41409": 1.0,
            "20705": 1.0,
            "10353": 1.0,
            "5177": 1.0,
            "2589": 1.0,
            "1295": 1.0,
            "648": 1.0,
            "324": 1.0,
            "162": 1.0,
            "81": 0.9994778067885117,
            "61": 0.9845953002610967,
            "56": 0.9516971279373369,
            "55": 0.9506527415143603, <--
            "54": 0.943864229765013,
            "51": 0.9057441253263707,
            "41": 0.860313315926893
        },
        "mlp_out.0": {
            "41409": 1.0,
            "20705": 1.0,
            "10353": 1.0,
            "5177": 1.0,
            "2589": 1.0,
            "1295": 1.0,
            "648": 1.0,
            "324": 1.0,
            "162": 1.0,
            "81": 1.0,
            "41": 1.0,
            "21": 1.0,
            "11": 1.0,
            "6": 0.9798955613577024, <--
            "5": 0.6652741514360313,
            "3": 0.16605744125326372
        }
    }
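For scale: with the rounded-up midpoint, the gap between the failing and passing counts shrinks to at most ceil(gap / 2) per step. A bisect over 0..n_max therefore needs at most 1 + ceil(log2(n_max)) score evaluations (one at n_max, then one per step), versus one per candidate count for an exhaustive sweep. A small sketch of that bound; bisect_evaluations is a hypothetical helper:

```python
import math


def bisect_evaluations(n_max: int) -> int:
    """Upper bound on score evaluations for a bisect over counts 0..n_max.

    One evaluation at n_max, then each step at least halves (rounding up) the
    gap between a failing count (starting at 0) and a passing count (starting
    at n_max) until they are adjacent.
    """
    if n_max < 1:
        raise ValueError("n_max must be at least 1")
    return 1 + math.ceil(math.log2(n_max))


if __name__ == "__main__":
    for n_max in (8321, 41409):
        print(n_max, bisect_evaluations(n_max))
    # 8321 15
    # 41409 17
```

That is at most 15 evaluations for the 8321 candidate counts in ln1.0 and at most 17 for the 41409 in ln2.0 and mlp_out.0, in line with the 14 to 16 entries per layer above.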
stefan-apollo commented 6 months ago

TODO: gets stuck with logarithmic scaling!

Fixed: I had forgotten to call .step() in one place.

Example config:

    exp_name: rib_modular_arithmetic_edge
    ablation_type: edge
    rib_results_path: rib_scripts/rib_build/sample_graphs/modular_arithmetic_rib_graph_sample.pt
    schedule:
      schedule_type: bisect
      scaling: logarithmic
      score_type: accuracy
      score_target: 0.95
    dataset:
      dataset_type: modular_arithmetic
      return_set: train
    ablation_node_layers:  # Rotate the input to these modules into the interaction basis
      - ln1.0
      - ln2.0
      - mlp_out.0
      - unembed
    batch_size: 128
    dtype: float32
    eval_type: accuracy
    seed: 0
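The scaling option in the schedule controls where the next count is probed. Here is a minimal sketch of one way to implement it. It assumes linear means the arithmetic midpoint and logarithmic means a geometric midpoint, shifted by one so that a lower bound of 0 is allowed. next_count is a hypothetical helper, not rib's implementation. Clamping the result into the open bracket guarantees that every step makes progress, so the midpoint rule alone cannot stall the search:

```python
import math


def next_count(lo: int, hi: int, scaling: str) -> int:
    """Pick the next count to evaluate, strictly between lo (fails) and hi (passes).

    Hypothetical sketch; requires hi - lo >= 2.
    """
    if hi - lo < 2:
        raise ValueError("bracket already resolved: nothing left to evaluate")
    if scaling == "linear":
        mid = (lo + hi + 1) // 2  # arithmetic midpoint, rounded up
    elif scaling == "logarithmic":
        # Geometric mean of lo + 1 and hi + 1, shifted back by one so lo = 0 works.
        mid = round(math.sqrt((lo + 1) * (hi + 1))) - 1
    else:
        raise ValueError(f"unknown scaling: {scaling!r}")
    # Guard: clamp into (lo, hi) so every step is guaranteed to shrink the bracket.
    return min(max(mid, lo + 1), hi - 1)


if __name__ == "__main__":
    print(next_count(0, 8321, "linear"), next_count(0, 8321, "logarithmic"))  # 4161 90
```

With lo = 0 and hi = 8321, the linear rule probes 4161 while the logarithmic rule probes 90. The logarithmic rule therefore spends its evaluations closer to the small counts where the thresholds above lie.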
stefan-apollo commented 6 months ago

All changes implemented