ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Support different bases and edges formulas #219

Closed stefan-apollo closed 9 months ago

stefan-apollo commented 9 months ago

Re-implementing the hacky changes from feature/new_norm_november_A and feature/new_norm_november_B. See the doc for details about the corresponding math; the Overleaf should soon have this info as well.

Replaces #188.

Changes:

Tests:

stefan-apollo commented 9 months ago

So far I have implemented the alternative integrated_gradient_trapezoidal_norm formula. It only changes attn slightly, as Lucius expected.


Next up: Changes in edges

stefan-apollo commented 9 months ago

I seem to have introduced a sign error in test_integrated_gradient_trapezoidal_jacobian_jacrev; I can't figure out why yet.

stefan-apollo commented 9 months ago

Found the bug, with Jake's help: in the test we were computing the old version of the formula (without the const - ...). This was the case in both the jacrev function _integrated_gradient_jacobian_with_jacrev and our function integrated_gradient_trapezoidal_jacobian. With the new changes, integrated_gradient_trapezoidal_jacobian always implements the new (const - ... alpha) formula, while the test still had the old one hardcoded, so the test and our function disagreed. Updating _integrated_gradient_jacobian_with_jacrev to match fixes this.
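For readers unfamiliar with the trapezoidal integrated-gradient estimate these functions are named after, here is a minimal sketch on a toy scalar function. It is not the rib implementation and reproduces neither the old nor the new formula discussed above; the names `trapezoid`, `df_dx`, and `integrated_gradient` are hypothetical. It only illustrates the general technique, plus a check (completeness against a zero baseline) that does not hardcode the integrand a second time.

```python
from typing import Callable


def trapezoid(integrand: Callable[[float], float], n_intervals: int) -> float:
    """Trapezoidal estimate of the integral of `integrand` over alpha in [0, 1]."""
    h = 1.0 / n_intervals
    # Endpoints get weight 1/2, interior points get weight 1.
    total = 0.5 * (integrand(0.0) + integrand(1.0))
    for i in range(1, n_intervals):
        total += integrand(i * h)
    return total * h


def df_dx(x: float) -> float:
    """Derivative of the toy function f(x) = x**3 (a stand-in for a network)."""
    return 3.0 * x**2


def integrated_gradient(x: float, n_intervals: int = 1000) -> float:
    """Integrated gradient of f from baseline 0 to x:
    x * integral over alpha in [0, 1] of f'(alpha * x)."""
    return x * trapezoid(lambda alpha: df_dx(alpha * x), n_intervals)


x = 2.0
ig = integrated_gradient(x)
# Completeness: the attribution should match f(x) - f(0) = x**3
# up to the discretisation error of the trapezoidal rule.
assert abs(ig - x**3) < 1e-4
```

A reference that is derived independently of the integrand, as the completeness check is here, cannot silently drift when the formula inside the implementation changes.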

stefan-apollo commented 9 months ago

According to Lucius, test_modular_arithmetic_build_graph is supposed to fail now with the new bases: the property it checks should no longer hold.

stefan-apollo commented 9 months ago

The floating point test for the interaction_rotations fails for the new basis. However, the test was already mostly failing for the old basis, so I suggest we scrap that particular test.

nix-apollo commented 9 months ago

Test runtime is rather large now, will reduce number of tests before merging

42.82s call     tests/test_build_graph.py::test_mnist_build_graph_new_basis_old_attribution
42.65s call     tests/test_build_graph.py::test_mnist_build_graph_old_basis_old_attribution
42.58s call     tests/test_build_graph.py::test_mnist_build_graph_new_basis_and_new_attribution

With #223 at least mnist build tests can be made faster by reducing data amount

stefan-apollo commented 9 months ago

And yeah, these tests take a long time; maybe we can disable some of them on CI or reduce the data amount later.

With https://github.com/ApolloResearch/rib/pull/223 at least mnist build tests can be made faster by reducing data amount

80.00s call     tests/test_build_graph.py::test_mnist_build_graph_old_basis_old_attribution[(1-0)*alpha-squared]
76.74s call     tests/test_build_graph.py::test_mnist_build_graph_old_basis_old_attribution[(1-alpha)^2-squared]
43.66s call     tests/test_build_graph.py::test_mnist_build_graph_old_basis_old_attribution[(1-alpha)^2-functional]
43.15s call     tests/test_build_graph.py::test_mnist_build_graph_old_basis_old_attribution[(1-0)*alpha-functional]
27.06s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
20.19s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_float_precision
17.76s call     tests/test_build_graph.py::test_pythia_14m_build_graph
12.01s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-alpha)^2-squared]
11.77s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-alpha)^2-functional]
11.64s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-0)*alpha-functional]
11.35s call     tests/test_ablations.py::test_run_mnist_ablations[orthogonal]
11.01s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-0)*alpha-squared]
8.81s call     tests/test_parallel.py::TestDistributed::test_edges_are_same
7.84s call     tests/test_ablations.py::test_run_mnist_ablations[rib]
6.84s call     tests/test_train_mnist.py::test_main_accuracy
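The listing above looks like the slowest-durations report that pytest prints with its `--durations` option. When deciding which tests to cut, it can help to pool the parametrised variants of each test. The sketch below does that for a few lines copied from the report; the helper name `total_by_test` and the regular expression are assumptions made for illustration, not part of rib or pytest.

```python
import re
from collections import defaultdict

# A few lines copied from the durations report above.
REPORT = """\
80.00s call     tests/test_build_graph.py::test_mnist_build_graph_old_basis_old_attribution[(1-0)*alpha-squared]
76.74s call     tests/test_build_graph.py::test_mnist_build_graph_old_basis_old_attribution[(1-alpha)^2-squared]
27.06s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
17.76s call     tests/test_build_graph.py::test_pythia_14m_build_graph
"""

# "<seconds>s <phase> <node id>", e.g. "80.00s call tests/...::test_x[param]".
LINE_RE = re.compile(
    r"^\s*(?P<secs>\d+(?:\.\d+)?)s\s+(?P<phase>setup|call|teardown)\s+(?P<node>\S+)\s*$"
)


def total_by_test(report: str) -> dict[str, float]:
    """Sum durations per test, pooling parametrised variants and phases."""
    totals: dict[str, float] = defaultdict(float)
    for line in report.splitlines():
        match = LINE_RE.match(line)
        if match is None:
            continue  # skip anything that is not a duration line
        # Drop the "[...]" parametrisation suffix so variants are pooled.
        name = match.group("node").split("[", 1)[0]
        totals[name] += float(match.group("secs"))
    return dict(totals)


totals = total_by_test(REPORT)
for name, secs in sorted(totals.items(), key=lambda kv: -kv[1]):
    print(f"{secs:8.2f}s  {name}")
```

Pooling by test name makes it visible that a single parametrised MNIST test accounts for most of the time in this excerpt, which supports reducing either its data amount or its parameter grid.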

stefan-apollo commented 9 months ago

I do have a problem with the variable_position_dimension, which has 12 occurrences in the repo and is only there in order to initialise this thing:

Oh I absolutely agree with this, I somehow missed your comment -- will change that

stefan-apollo commented 9 months ago

Tests finished, merging!