Closed: stefan-apollo closed this pull request 6 months ago
Example run on modular addition (mod add) -- we can always see the single-node step where accuracy drops from >0.95 to <0.95:
"results": {
"ln1.0": {
"65": 1.0,
"33": 1.0,
"17": 0.9992167101827676,
"13": 0.9892950391644909,
"11": 0.954046997389034, <--
"10": 0.9409921671018276,
"9": 0.8861618798955614
},
"ln2.0": {
"65": 1.0,
"33": 1.0,
"17": 1.0,
"15": 0.9942558746736292,
"14": 0.9684073107049609, <--
"13": 0.9417754569190601,
"9": 0.6966057441253264
},
"mlp_out.0": {
"321": 1.0,
"161": 1.0,
"81": 1.0,
"41": 1.0,
"21": 1.0,
"11": 1.0,
"6": 0.9798955613577024, <--
"5": 0.6652741514360313,
"3": 0.16605744125326372
},
"unembed": {
"65": 1.0,
"33": 1.0,
"17": 1.0,
"9": 1.0,
"7": 0.9911227154046998,
"6": 0.9798955613577024, <--
"5": 0.6652741514360313
}
},
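The coarse pass in the results above roughly halves the number of kept nodes (65, 33, 17, ...) until the score falls below the target, and bisection then pins down the exact boundary. A minimal sketch of such a halving schedule, assuming `exp_base=2` (function name hypothetical, not the repo's implementation):

```python
import math

def log_schedule(n_vecs: int, exp_base: float = 2.0) -> list[int]:
    """Coarse logarithmic schedule: divide the number of kept vectors
    by exp_base (rounding up) until only one remains.

    Hypothetical sketch; assumes exp_base > 1.
    """
    points, n = [], n_vecs
    while True:
        points.append(n)
        if n == 1:
            return points
        n = math.ceil(n / exp_base)

# Coarse points before bisection refines the >0.95 / <0.95 gap.
print(log_schedule(65))
```

This reproduces the 65, 33, 17, 9, ... progression seen in the node results; the intermediate values (e.g. 13, 11, 10 for ln1.0) come from the bisection step between the coarse points.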
And for edges -- where an exhaustive sweep over edge counts would otherwise be very expensive!
"results": {
"ln1.0": {
"8321": 0.9953002610966057,
"4161": 0.9953002610966057,
"2081": 0.9953002610966057,
"1041": 0.9953002610966057,
"521": 0.9953002610966057,
"261": 0.995822454308094,
"131": 0.9642297650130548,
"123": 0.9624020887728459,
"119": 0.9660574412532638,
"117": 0.9639686684073107, <--
"116": 0.9498694516971279,
"115": 0.9488250652741514,
"99": 0.8960835509138381,
"66": 0.6657963446475196
},
"ln2.0": {
"41409": 1.0,
"20705": 1.0,
"10353": 1.0,
"5177": 1.0,
"2589": 1.0,
"1295": 1.0,
"648": 1.0,
"324": 1.0,
"162": 1.0,
"81": 0.9994778067885117,
"61": 0.9845953002610967,
"56": 0.9516971279373369,
"55": 0.9506527415143603, <--
"54": 0.943864229765013,
"51": 0.9057441253263707,
"41": 0.860313315926893
},
"mlp_out.0": {
"41409": 1.0,
"20705": 1.0,
"10353": 1.0,
"5177": 1.0,
"2589": 1.0,
"1295": 1.0,
"648": 1.0,
"324": 1.0,
"162": 1.0,
"81": 1.0,
"41": 1.0,
"21": 1.0,
"11": 1.0,
"6": 0.9798955613577024, <--
"5": 0.6652741514360313,
"3": 0.16605744125326372
}
}
Todo: Gets stuck with logarithmic scaling!
Fixed -- I had forgotten a `.step()` call at one point!
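The "stuck" behaviour is consistent with a driver loop that only advances when the schedule's `.step()` is called; a hedged sketch of that failure mode (class and method names hypothetical, not the repo's implementation):

```python
# Hypothetical sketch: the schedule only advances when .step() is
# called, so omitting that call evaluates the same point forever.
class BisectSchedule:
    def __init__(self, n_vecs: int):
        self.lo, self.hi = 0, n_vecs  # bracket around the score threshold

    @property
    def done(self) -> bool:
        return self.hi - self.lo <= 1

    @property
    def current(self) -> int:
        return (self.lo + self.hi) // 2

    def step(self, passed: bool) -> None:
        # Narrow the bracket; without this call, `current` never changes.
        if passed:
            self.hi = self.current
        else:
            self.lo = self.current

# Toy score resembling the ln1.0 results: accuracy >= 0.95 iff k >= 11.
def score(k: int) -> float:
    return 1.0 if k >= 11 else 0.9

schedule = BisectSchedule(n_vecs=65)
while not schedule.done:
    # Omitting this .step() call turns the loop into an infinite spin.
    schedule.step(score(schedule.current) >= 0.95)
```

After the loop, `schedule.hi` holds the smallest node count that still meets the target (11 in this toy case, matching the ln1.0 column above).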
exp_name: rib_modular_arithmetic_edge
ablation_type: edge
rib_results_path: rib_scripts/rib_build/sample_graphs/modular_arithmetic_rib_graph_sample.pt
schedule:
schedule_type: bisect
scaling: logarithmic
score_type: accuracy
score_target: 0.95
dataset:
dataset_type: modular_arithmetic
return_set: train
ablation_node_layers: # Rotate the input to these modules into the interaction basis
- ln1.0
- ln2.0
- mlp_out.0
- unembed
batch_size: 128
dtype: float32
eval_type: accuracy
seed: 0
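For reference, the `schedule` block of the config above could be mirrored by a small config object; a hedged sketch (class and defaults hypothetical, field names taken from the YAML keys):

```python
from dataclasses import dataclass

@dataclass
class ScheduleConfig:
    # Mirrors the `schedule:` block in the YAML above.
    schedule_type: str = "bisect"     # coarse pass + bisection refinement
    scaling: str = "logarithmic"      # coarse points spaced by halving
    score_type: str = "accuracy"
    score_target: float = 0.95        # the >0.95 / <0.95 boundary searched for

cfg = ScheduleConfig()
```

Keeping the threshold in `score_target` (rather than hard-coding 0.95) is what lets the same schedule answer questions like "how many edges for 99.9% performance" by swapping in a different target.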
All changes implemented
Refactor ablation schedules & implement bisect schedule
Description
The check was `self.config.n_points <= self._n_vecs`, but actually, for e.g. n_vecs=2 we want to allow n_points=3 ([0, 1, 2]), so it was changed to `self.config.n_points <= self._n_vecs + 1`.

Replaced `exp_base = self.exp_base if self.exp_base is not None else 2.0` by giving `exp_base` a default value instead.

`tqdm(ablation_schedule[::-1],` -- the reversal is handled by applying `[::-1]` internally.

Motivation and Context
The main metrics for ablation tests are of the form "how many edges do we need for 99.9% performance?". Rather than zooming into badly sampled graphs, we should just bisect for this threshold.
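The bisection itself is just a binary search over the number of edges kept, assuming the score is roughly monotone in that count. A minimal sketch (function name hypothetical, not the repo's implementation):

```python
def bisect_threshold(score, n_total: int, target: float = 0.95) -> int:
    """Smallest k in (0, n_total] with score(k) >= target.

    Hypothetical sketch; assumes score is (roughly) monotone
    increasing in k and that score(n_total) meets the target.
    """
    lo, hi = 0, n_total
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if score(mid) >= target:
            hi = mid  # mid suffices; search below it
        else:
            lo = mid  # mid fails; search above it
    return hi

# Toy score: accuracy passes once 95 of 100 edges are kept.
print(bisect_threshold(lambda k: k / 100, 100))
```

This needs only O(log n) score evaluations instead of the O(n) an exhaustive sweep over edge counts would take, which is why it is viable even with tens of thousands of edges.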
How Has This Been Tested?
Questions