ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Fix M matrix floating point issues #203

Closed: stefan-apollo closed this 11 months ago

stefan-apollo commented 11 months ago

Description

Based on feature/force_overwrite_output; merge that branch before this PR (already merged).

Hopefully fixes the remaining floating point issues.

In practice, it fixes two of our floating point issues most of the time:

  1. It turns out that M and M dash need to be kept in float64 at all times (until the eigendecomposition). Even momentarily converting either of them to float32 rounds it and breaks the ablation curves.
  2. It turns out that the einsum for Lambda_dash also needs to be run in float64.
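A minimal sketch of the two rules above, using NumPy as a stand-in for the repo's PyTorch code. The helper names (eigendecompose_M, lambda_dash_f64) and the einsum subscripts are hypothetical, not the repo's API. The 2x2 matrix illustrates why even a momentary float32 round trip hurts: an entry of 1 + 1e-9 rounds to exactly 1.0 in float32, so the small eigenvalue (about 5e-10) collapses to zero.

```python
import numpy as np


def eigendecompose_M(M: np.ndarray, out_dtype=np.float32):
    """Hypothetical helper: eigendecompose a symmetric M that is still float64.

    Only the eigenvalues/eigenvectors are cast to the working dtype, and only
    after np.linalg.eigh has run.
    """
    if M.dtype != np.float64:
        raise TypeError("M must stay float64 until the eigendecomposition")
    eigvals, eigvecs = np.linalg.eigh(M)  # eigenvalues in ascending order
    return eigvals.astype(out_dtype), eigvecs.astype(out_dtype)


def lambda_dash_f64(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical Lambda_dash-style contraction, always computed in float64.

    The einsum subscripts are illustrative, not the ones used in the repo.
    """
    return np.einsum("bi,bj->ij", a.astype(np.float64), b.astype(np.float64))


# Nearly singular symmetric M: eigenvalues are roughly 2 and 5e-10.
M = np.array([[1.0, 1.0], [1.0, 1.0 + 1e-9]], dtype=np.float64)
small_f64 = eigendecompose_M(M)[0][0]

# A "momentary" float32 round trip: 1 + 1e-9 rounds to exactly 1.0 in float32,
# so the small eigenvalue collapses to (numerically) zero.
M_round_trip = M.astype(np.float32).astype(np.float64)
small_round_trip = eigendecompose_M(M_round_trip)[0][0]

# Float32 inputs are upcast before the contraction; the result is float64.
a = np.ones((3, 2), dtype=np.float32)
lam = lambda_dash_f64(a, a)

print(f"float64 throughout:       {small_f64:.3e}")
print(f"after float32 round trip: {small_round_trip:.3e}")
print(lam.dtype)
```

The design choice is to cast only the outputs of the eigendecomposition, so any downstream float32 code never sees a rounded M.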

Tested

Implemented tests that

Future work:

Also manually tested by checking that the ablation curves stay flat down to 128 with these configs:

Build

exp_name: debug-pythia-14m
force_overwrite_output: true
seed: 0
tlens_pretrained: pythia-14m
tlens_model_path: null
dataset:
  source: huggingface
  name: NeelNanda/pile-10k
  tokenizer_name: EleutherAI/pythia-14m
  return_set: train  # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
  return_set_frac: null
  return_set_n_samples: 10
  return_set_portion: first
node_layers:
  - ln1.0
  - mlp_out.0
  - ln2.3
  - mlp_out.3
  - ln1.5
  - mlp_out.5
  - output
batch_size: 4  #  A100 can handle 24
gram_batch_size: 20  #  A100 can handle 80
truncation_threshold: 1e-6
rotate_final_node_layer: false
n_intervals: 10
dtype: float32
calculate_edges: false
eval_type: null

Ablate

exp_name: debug-pythia-14m
force_overwrite_output: true
ablation_type: rib
interaction_graph_path: /mnt/ssd-apollo/stefan/rib/experiments/lm_rib_build/out/debug-pythia-14m_rib_Cs.pt
schedule:
  schedule_type: linear
  early_stopping_threshold: 1.5
  n_points: 20
  specific_points: [128, 129, 130]
dataset:
  source: huggingface
  name: NeelNanda/pile-10k
  tokenizer_name: EleutherAI/pythia-14m
  return_set: train  # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
  return_set_frac: null
  return_set_n_samples: 10
  return_set_portion: first
ablation_node_layers:
  - ln1.0
  - mlp_out.0
  - ln2.3
  - mlp_out.3
  - ln1.5
  - mlp_out.5
batch_size: 30  # A100 can handle 60
dtype: float32
eval_type: ce_loss
seed: 0

Result: [plot: debug-pythia-14m_ce_loss_vs_ablated_vecs]

The results were also tested on a larger dataset (return_set_frac: 0.1).

PS: My debugging run script

#!/bin/bash
set -e
python /mnt/ssd-apollo/stefan/rib/experiments/lm_rib_build/run_lm_rib_build.py /mnt/ssd-apollo/stefan/rib/experiments/lm_rib_build/fptest_pythia.yaml
python /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/run_lm_ablations.py /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/fptest_ablate_pythia.yaml
python /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/plot_lm_ablations.py /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/out/debug-pythia-14m_ablation_results.json -f
nix-apollo commented 11 months ago

Is there a not-terribly-expensive test or two we can add from this experience? I'm thinking things like:

  1. Running the experiment in float32 (except for M dash) gives the same result as running it in float64.
  2. The ablation curves are flat in the way they should be.
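A cheap, dependency-free sketch of both checks, assuming each ablation result is a dict mapping the ablation point (as a string) to CE loss, like the curves printed later in this thread. The function names are hypothetical; the 1e-3 absolute tolerance mirrors the atol in the test output, and math.isclose with rel_tol=0 stands in for torch.allclose. The sample curves are the float32 and float64 numbers from the CI failure reported in this thread.

```python
import math

# Curves copied from the CI failure reported in this thread (keys are ablation points).
CURVE_F32 = {"642": 3.776362737019857, "321": 3.7787582079569497,
             "128": 3.77876877784729, "0": 10.825837135314941}
CURVE_F64 = {"642": 3.7763756431303417, "321": 3.7763756431306406,
             "128": 3.776375642827917, "0": 10.825839875788878}


def is_flat(curve: dict, down_to: str = "128", atol: float = 1e-3) -> bool:
    """Hypothetical check: the loss at `down_to` matches the loss at the largest key."""
    largest = max(curve, key=int)
    return math.isclose(curve[down_to], curve[largest], rel_tol=0.0, abs_tol=atol)


def dtypes_agree(curve_a: dict, curve_b: dict, atol: float = 1e-3) -> bool:
    """Hypothetical check: two curves agree at every shared ablation point."""
    shared = curve_a.keys() & curve_b.keys()
    return all(
        math.isclose(curve_a[k], curve_b[k], rel_tol=0.0, abs_tol=atol)
        for k in shared
    )


print(is_flat(CURVE_F64), is_flat(CURVE_F32), dtypes_agree(CURVE_F32, CURVE_F64))
# → True False False
```

Both checks run on already-computed results, so they add no cost beyond the build and ablation runs themselves.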
stefan-apollo commented 11 months ago

I'll implement these tests together with Nix ~tomorrow

stefan-apollo commented 11 months ago

The test fails on CPU, so I'm giving up for now and only running the non-ablation tests on GPU.

INFO     root:run_lm_rib_build.py:304 Time to load model and dataset: 11.03
INFO     root:run_lm_rib_build.py:331 Collecting gram matrices for 1 batches.
INFO     root:run_lm_rib_build.py:342 Time to collect gram matrices: 0.93
INFO     root:run_lm_rib_build.py:348 Calculating interaction rotations (Cs).
INFO     root:run_lm_rib_build.py:365 Time to calculate Cs: 0.6 minutes
INFO     root:run_lm_rib_build.py:371 Skipping edge calculation.
INFO     root:run_lm_rib_build.py:441 Saved results to /tmp/tmpm3rc6_3k/float-precision-test-pythia-14m-float32_rib_Cs.pt
INFO     root:run_lm_rib_build.py:304 Time to load model and dataset: 2.57
INFO     root:run_lm_rib_build.py:331 Collecting gram matrices for 1 batches.
INFO     root:run_lm_rib_build.py:342 Time to collect gram matrices: 1.37
INFO     root:run_lm_rib_build.py:348 Calculating interaction rotations (Cs).
INFO     root:run_lm_rib_build.py:365 Time to calculate Cs: 1.0 minutes
INFO     root:run_lm_rib_build.py:371 Skipping edge calculation.
INFO     root:run_lm_rib_build.py:441 Saved results to /tmp/tmpm3rc6_3k/float-precision-test-pythia-14m-float64_rib_Cs.pt
================================ short test summary info ================================
FAILED tests/test_float_precision.py::test_pythia_floating_point_errors - AssertionError: 1
========================= 1 failed, 71 deselected in 122.60s (0:02:02) =========================
stefan-apollo commented 11 months ago

Okay, pytest --runslow -k test_pythia_floating_point_errors now runs on CPU, taking ~2-3 minutes.

stefan-apollo commented 11 months ago

While all tests pass on GPU, the float32 ablations break on CPU.

    @pytest.mark.parametrize("dtype", ["float32", "float64"])
    def test_ablation_result_flatness(self, ablation_results: dict, dtype: str) -> None:
        for node_layer in ablation_results["float32"].keys():
            if "mlp_out" in node_layer:
                # Should be identical due to residual stream size
>               ablation_result_128 = ablation_results[dtype][node_layer]["128"]
E               AssertionError: MLP non-flat ablation curve float32 mlp_out.0: 3.7895944118499756 (128) != 3.77636981010437 (642)
E               assert False
E                +  where False = <built-in method allclose of type object at 0x7f55fe59c540>(tensor(3.7896), tensor(3.7764), atol=0.001)
E                +    where <built-in method allclose of type object at 0x7f55fe59c540> = torch.allclose
E                +    and   tensor(3.7896) = <built-in method tensor of type object at 0x7f55fe59c540>(3.7895944118499756)
E                +      where <built-in method tensor of type object at 0x7f55fe59c540> = torch.tensor
E                +    and   tensor(3.7764) = <built-in method tensor of type object at 0x7f55fe59c540>(3.77636981010437)
E                +      where <built-in method tensor of type object at 0x7f55fe59c540> = torch.tensor

tests/test_float_precision.py:204: AssertionError
================================ short test summary info ================================
FAILED tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_float_precision - AssertionError: Float difference mlp_out.0 128: 3.7895944118499756 (float32) != 3.7763756462461164 (float32)
FAILED tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_flatness[float32] - AssertionError: MLP non-flat ablation curve float32 mlp_out.0: 3.7895944118499756 (128) != 3.77636981010437 (642)
====================== 2 failed, 3 passed, 1 skipped in 214.94s (0:03:34) ======================

This is probably due to slightly different C matrices rather than to ablation_config using float32, but I will test it. Edit: Confirmed, ablation_config["dtype"] does not matter.

stefan-apollo commented 11 months ago

These ablation curve errors are (unlike the test_interaction_rotations ones we had, which we are now skipping) not related to batch size, and they do not occur on GPU even with

            batch_size: 1  #  A100 can handle 24
            gram_batch_size: 1  #  A100 can handle 80
stefan-apollo commented 11 months ago

The CPU tests do pass if I calculate and accumulate the Lambda matrices in float64. I had changed
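The exact change is not shown in the thread, so here is a generic NumPy sketch of why the accumulation dtype matters. The helper accumulate_lambda and its einsum subscripts are hypothetical. Once the running sum is around 1.0, any per-batch contribution smaller than half a float32 ulp (about 6e-8) is rounded away entirely, while a float64 accumulator keeps it.

```python
import numpy as np


def accumulate_lambda(batches, acc_dtype) -> np.ndarray:
    """Hypothetical helper: sum per-batch x^T x contributions in `acc_dtype`.

    Both the einsum and the running sum are done in `acc_dtype`.
    """
    acc = None
    for x in batches:
        x = np.asarray(x, dtype=acc_dtype)
        contribution = np.einsum("bi,bj->ij", x, x)
        acc = contribution if acc is None else acc + contribution
    return acc


# One large batch followed by many tiny ones (each contributes about 1e-8).
batches = [np.array([[1.0]])] + [np.array([[1e-4]])] * 10_000

lam_f32 = accumulate_lambda(batches, np.float32)
lam_f64 = accumulate_lambda(batches, np.float64)
# float32 stays at exactly 1.0; float64 reaches about 1.0001.
print(lam_f32[0, 0], lam_f64[0, 0])
```

Real activations are not this contrived, but the same effect applies whenever many small batch contributions are added to a much larger running total.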

stefan-apollo commented 11 months ago

Now I tested:

Okay, so it looks like einsum on CPU was the issue.

stefan-apollo commented 11 months ago

I manually confirmed that the tests pass with different seeds (tried 3 different seeds on GPU)

stefan-apollo commented 11 months ago

Also, the test which takes ~2 min on CPU devbox takes ~24min on CI runner.

nix-apollo commented 11 months ago

> Also, the test which takes ~2 min on CPU devbox takes ~24min on CI runner.

It seems pretty sad to have CI go from ~7 to ~30 mins! Is it possible to adjust the config so it runs faster? Or do we just need to --skip-ci this test until we have GPU runners (#170)?

stefan-apollo commented 11 months ago

Tests continue to fail on CI, while they pass on CPU and GPU devboxes. The float32 ablation curve is (a) not flat and (b) different from the float64 one.

{'642': 3.776362737019857, '321': 3.7787582079569497, '128': 3.77876877784729, '0': 10.825837135314941} (float32)
!=
{'642': 3.7763756431303417, '321': 3.7763756431306406, '128': 3.776375642827917, '0': 10.825839875788878} (float64)
stefan-apollo commented 11 months ago

Nix: Based on the output, it seems a different test is the one that takes long, but I agree this makes no sense to me, and my test could totally be the slow one. [Screenshot 2023-11-22 at 14 14 26]

stefan-apollo commented 11 months ago

The output was just wrong: the actual slowest tests were the new ones, as expected.

============================= slowest 10 durations =============================
9.73s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
88.86s call     tests/test_build_graph.py::test_pythia_14m_build_graph
38.82s call     tests/test_folded_bias.py::test_gpt2_folded_bias
30.09s call     tests/test_build_graph.py::test_mnist_build_graph
19.85s call     tests/test_folded_bias.py::test_pythia_folded_bias
14.28s call     tests/test_train_modular_arithmetic.py::test_main_accuracy
13.77s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph
9.50s call     tests/test_ablations.py::test_run_mnist_orthog_ablations
7.51s call     tests/test_train_mnist.py::test_main_accuracy
7.15s call     tests/test_ablations.py::test_run_modular_arithmetic_rib_ablations
=========== 71 passed, 1 skipped, 5 deselected in 596.81s (0:09:56) ============

Skipping those on CI from now on.
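A minimal, pytest-free sketch of a CI skip condition. It assumes the runner sets the CI environment variable to true, which GitHub Actions does by default; the helper name running_on_ci is hypothetical, and this is not necessarily how the repo implements the skip. The commented decorator shows how it could be wired into pytest.mark.skipif.

```python
import os
from typing import Mapping, Optional


def running_on_ci(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if the environment looks like a CI runner (CI=true)."""
    env = os.environ if env is None else env
    return env.get("CI", "").strip().lower() == "true"


# Hypothetical usage in a test module (requires pytest):
#
# @pytest.mark.skipif(running_on_ci(), reason="Too slow on CPU-only CI runners")
# class TestPythiaFloatingPointErrors:
#     ...

print(running_on_ci({"CI": "true"}), running_on_ci({}))  # → True False
```

Passing the environment in explicitly keeps the helper easy to test without touching os.environ.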

Edit: For posterity's sake, durations of both new tests:


============================= slowest 10 durations =============================
634.99s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
382.80s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_float_precision
157.70s call     tests/test_build_graph.py::test_pythia_14m_build_graph
62.86s call     tests/test_build_graph.py::test_mnist_build_graph
54.36s call     tests/test_folded_bias.py::test_gpt2_folded_bias
24.69s call     tests/test_train_modular_arithmetic.py::test_main_accuracy
23.68s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph
19.62s call     tests/test_folded_bias.py::test_pythia_folded_bias
19.04s call     tests/test_ablations.py::test_run_mnist_orthog_ablations
14.99s call     tests/test_train_mnist.py::test_main_accuracy
=========================== short test summary info ============================
stefan-apollo commented 11 months ago

Looks like we're not the only ones with dtype-related issues on GitHub CI: https://opensourcemechanistic.slack.com/archives/C04SRRE96UV/p1700313374803079