Closed: stefan-apollo closed this issue 11 months ago
Is there a not-terribly-expensive test or two we can add from this experience? I'm thinking things like:
I'll implement these tests together with Nix ~tomorrow
Test fails on CPU -- giving up, only running the non-ablation tests on GPU for now.
INFO root:run_lm_rib_build.py:304 Time to load model and dataset: 11.03
INFO root:run_lm_rib_build.py:331 Collecting gram matrices for 1 batches.
INFO root:run_lm_rib_build.py:342 Time to collect gram matrices: 0.93
INFO root:run_lm_rib_build.py:348 Calculating interaction rotations (Cs).
INFO root:run_lm_rib_build.py:365 Time to calculate Cs: 0.6 minutes
INFO root:run_lm_rib_build.py:371 Skipping edge calculation.
INFO root:run_lm_rib_build.py:441 Saved results to /tmp/tmpm3rc6_3k/float-precision-test-pythia-14m-float32_rib_Cs.pt
INFO root:run_lm_rib_build.py:304 Time to load model and dataset: 2.57
INFO root:run_lm_rib_build.py:331 Collecting gram matrices for 1 batches.
INFO root:run_lm_rib_build.py:342 Time to collect gram matrices: 1.37
INFO root:run_lm_rib_build.py:348 Calculating interaction rotations (Cs).
INFO root:run_lm_rib_build.py:365 Time to calculate Cs: 1.0 minutes
INFO root:run_lm_rib_build.py:371 Skipping edge calculation.
INFO root:run_lm_rib_build.py:441 Saved results to /tmp/tmpm3rc6_3k/float-precision-test-pythia-14m-float64_rib_Cs.pt
====================================================================================================== short test summary info =======================================================================================================
FAILED tests/test_float_precision.py::test_pythia_floating_point_errors - AssertionError: 1
========================================================================================= 1 failed, 71 deselected in 122.60s (0:02:02) ==========================================================================================
Okay, pytest --runslow -k test_pythia_floating_point_errors now runs on CPU, taking ~2-3 minutes.
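For reference, the --runslow gating usually comes from a conftest.py hook pair like the standard pattern from the pytest docs (a sketch; the exact marker and option names in this repo's conftest are assumed):

```python
# conftest.py sketch of the standard pytest "--runslow" pattern
import pytest


def pytest_addoption(parser):
    # register a custom command-line flag
    parser.addoption("--runslow", action="store_true", default=False,
                     help="run tests marked as slow")


def pytest_collection_modifyitems(config, items):
    # without --runslow, attach a skip marker to every @pytest.mark.slow test
    if config.getoption("--runslow"):
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
```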
While all tests pass on GPU, the float32 ablations break on CPU.
@pytest.mark.parametrize("dtype", ["float32", "float64"])
def test_ablation_result_flatness(self, ablation_results: dict, dtype: str) -> None:
for node_layer in ablation_results["float32"].keys():
if "mlp_out" in node_layer:
# Should be identical due to residual stream size
> ablation_result_128 = ablation_results[dtype][node_layer]["128"]
E AssertionError: MLP non-flat ablation curve float32 mlp_out.0: 3.7895944118499756 (128) != 3.77636981010437 (642)
E assert False
E + where False = <built-in method allclose of type object at 0x7f55fe59c540>(tensor(3.7896), tensor(3.7764), atol=0.001)
E + where <built-in method allclose of type object at 0x7f55fe59c540> = torch.allclose
E + and tensor(3.7896) = <built-in method tensor of type object at 0x7f55fe59c540>(3.7895944118499756)
E + where <built-in method tensor of type object at 0x7f55fe59c540> = torch.tensor
E + and tensor(3.7764) = <built-in method tensor of type object at 0x7f55fe59c540>(3.77636981010437)
E + where <built-in method tensor of type object at 0x7f55fe59c540> = torch.tensor
tests/test_float_precision.py:204: AssertionError
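For context, torch.allclose (like numpy's isclose) checks |a - b| <= atol + rtol * |b|, and the gap between the two results above is roughly 0.013, an order of magnitude over the atol=1e-3 tolerance. A standalone numpy sketch of the same comparison:

```python
import numpy as np

# the two ablation results from the failing assertion above
a, b = 3.7895944118499756, 3.77636981010437

# torch.allclose / np.isclose both check |a - b| <= atol + rtol * |b|
diff = abs(a - b)                                    # ~0.0132, well above atol=1e-3
close = bool(np.isclose(a, b, atol=1e-3, rtol=0.0))  # False
```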
====================================================================================================== short test summary info =======================================================================================================
FAILED tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_float_precision - AssertionError: Float difference mlp_out.0 128: 3.7895944118499756 (float32) != 3.7763756462461164 (float32)
FAILED tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_flatness[float32] - AssertionError: MLP non-flat ablation curve float32 mlp_out.0: 3.7895944118499756 (128) != 3.77636981010437 (642)
========================================================================================= 2 failed, 3 passed, 1 skipped in 214.94s (0:03:34) =========================================================================================
This is probably due to slightly different C matrices rather than the ablation_config using float32, but I will test. Edit: Confirmed, ablation_config["dtype"] does not matter.
These ablation curve errors are (unlike the test_interaction_rotations
ones we had, which we are now skipping) not related to batch size, and do not occur on GPU even with
batch_size: 1 # A100 can handle 24
gram_batch_size: 1 # A100 can handle 80
The CPU tests do pass if I calculate and accumulate the Lambda matrices in float64. I had changed
Now I tested:
Okay, so looks like einsum on CPU was the issue.
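A quick way to see the kind of drift involved (a standalone numpy illustration, not the repo's code): run the same Gram/Lambda-style contraction with float32 accumulation and with float64 accumulation and compare. The drift is small but nonzero, and its exact size depends on the backend's summation order, which is why CPU and GPU can land on different sides of a tolerance.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4096, 256)).astype(np.float32)

# same contraction, accumulated in float32 vs float64
lam32 = np.einsum("bi,bj->ij", x, x)
lam64 = np.einsum("bi,bj->ij", x.astype(np.float64), x.astype(np.float64))

# small but nonzero relative drift from float32 rounding during accumulation
rel = np.abs(lam32 - lam64).max() / np.abs(lam64).max()
```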
I manually confirmed that the tests pass with different seeds (tried 3 different seeds on GPU)
Also, the test which takes ~2 min on CPU devbox takes ~24min on CI runner.
Seems pretty sad to have CI go from ~7 to ~30 mins! Is it possible to adjust the config to have it run faster? Or do we just need to --skip-ci this test until we have GPU runners (#170)?
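One low-effort way to gate the test until GPU runners exist (a sketch, not this repo's actual mechanism; it assumes GitHub Actions, which exports CI=true in the environment):

```python
import os

import pytest

# GitHub Actions (and most CI providers) set CI=true in the environment
IN_CI = os.environ.get("CI") == "true"


@pytest.mark.skipif(IN_CI, reason="~24 min on CI runners without GPUs (#170)")
def test_pythia_floating_point_errors():
    ...  # body elided; the real test lives in tests/test_float_precision.py
```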
Tests continue to fail on the CI while they work on CPU and GPU devboxes. The float32 ablation curve is (a) not flat and (b) different from the float64 one.
{'642': 3.776362737019857, '321': 3.7787582079569497, '128': 3.77876877784729, '0': 10.825837135314941} (float32)
!=
{'642': 3.7763756431303417, '321': 3.7763756431306406, '128': 3.776375642827917, '0': 10.825839875788878} (float64)
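A minimal flatness check over those two dicts (illustrative only; the helper name and the atol=1e-3 threshold are mine, matching the tolerance in the failing test): the float64 curve varies by ~3e-10 between 128 and 642 kept dims, while the float32 curve varies by ~2.4e-3.

```python
# ablation-loss curves copied from the CI failure above, keyed by dims kept
curve32 = {'642': 3.776362737019857, '321': 3.7787582079569497,
           '128': 3.77876877784729, '0': 10.825837135314941}
curve64 = {'642': 3.7763756431303417, '321': 3.7763756431306406,
           '128': 3.776375642827917, '0': 10.825839875788878}


def flat_above(curve, n_min=128, atol=1e-3):
    """True if the curve varies by at most atol for all points with >= n_min dims kept."""
    vals = [v for k, v in curve.items() if int(k) >= n_min]
    return max(vals) - min(vals) <= atol


flat64 = flat_above(curve64)  # flat: spread ~3e-10
flat32 = flat_above(curve32)  # not flat: spread ~2.4e-3
```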
Nix: Based on the output it seems it's a different test that takes long, but I agree this makes no sense to me and my test could totally take long.
The output was just wrong, the actual slowest tests were the new ones as expected
============================= slowest 10 durations =============================
9.73s setup tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
88.86s call tests/test_build_graph.py::test_pythia_14m_build_graph
38.82s call tests/test_folded_bias.py::test_gpt2_folded_bias
30.09s call tests/test_build_graph.py::test_mnist_build_graph
19.85s call tests/test_folded_bias.py::test_pythia_folded_bias
14.28s call tests/test_train_modular_arithmetic.py::test_main_accuracy
13.77s call tests/test_build_graph.py::test_modular_arithmetic_build_graph
9.50s call tests/test_ablations.py::test_run_mnist_orthog_ablations
7.51s call tests/test_train_mnist.py::test_main_accuracy
7.15s call tests/test_ablations.py::test_run_modular_arithmetic_rib_ablations
=========== 71 passed, 1 skipped, 5 deselected in 596.81s (0:09:56) ============
Skipping those on CI from now on.
Edit: For posterity's sake, durations of both new tests:
============================= slowest 10 durations =============================
634.99s setup tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
382.80s setup tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_float_precision
157.70s call tests/test_build_graph.py::test_pythia_14m_build_graph
62.86s call tests/test_build_graph.py::test_mnist_build_graph
54.36s call tests/test_folded_bias.py::test_gpt2_folded_bias
24.69s call tests/test_train_modular_arithmetic.py::test_main_accuracy
23.68s call tests/test_build_graph.py::test_modular_arithmetic_build_graph
19.62s call tests/test_folded_bias.py::test_pythia_folded_bias
19.04s call tests/test_ablations.py::test_run_mnist_orthog_ablations
14.99s call tests/test_train_mnist.py::test_main_accuracy
=========================== short test summary info ============================
Looks like we're not the only ones with dtype related issues on GitHub CI: https://opensourcemechanistic.slack.com/archives/C04SRRE96UV/p1700313374803079
Description
Based on feature/force_overwrite_output, merge that before this PR (Merged)
Hopefully fixes remaining floating point issues. Edit: Fixes two of our floating point issues, most of the time.
Tested
Implemented tests that
Future work:
Also manually tested by observing that ablation curves stay flat until 128 with these configs: Build
Ablate
Result: The results were also tested on a larger dataset of
return_set_frac: 0.1
PS: My debugging run script