ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Investigate float precision issues in RIB #135

Open danbraunai opened 11 months ago

danbraunai commented 11 months ago

We currently set the dtype to float32 for pythia in lm_rib_build and lm_ablations. Check whether we can get away with using bfloat16 for a bunch of the computation.

Note that the eigendecomposition happens in float64, which seemed necessary in previous experiments with modular_arithmetic, but this should be checked for LMs.
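
For reference, a minimal sketch of the mixed-precision pattern being described (the helper name and the cast-back step are illustrative, not the actual rib implementation): run most of the computation in a lower-precision working dtype, but upcast to float64 just for the eigendecomposition.

import torch

def eigendecompose_f64(gram: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative pattern: do the eigendecomposition in float64, then cast the
    # results back to the working dtype (e.g. float32 or bfloat16).
    working_dtype = gram.dtype
    # eigh assumes a symmetric input, which holds for gram matrices.
    eigenvalues, eigenvectors = torch.linalg.eigh(gram.to(torch.float64))
    return eigenvalues.to(working_dtype), eigenvectors.to(working_dtype)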

danbraunai-apollo commented 10 months ago

Worse than this, we need to check why float32 seems to give different results from float64 (see https://github.com/ApolloResearch/rib/pull/185#issuecomment-1810637999)

Rather than doing a deep dive into all this, it might be better to just improve our normalization so that all of our values are constrained. Then we can check whether there are precision differences.
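
One illustrative way to constrain the scale before comparing precisions (purely a sketch under my own assumptions, not necessarily the normalization rib would use) is to rescale each gram matrix to unit trace before the eigendecomposition:

import torch

def normalize_gram(gram: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Rescale a symmetric PSD gram matrix to unit trace so its entries stay in a
    # bounded range regardless of the activation scale.
    return gram / (gram.trace() + eps)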

stefan-apollo commented 10 months ago

Applies to mod add as well; renamed the issue.

nix-apollo commented 10 months ago

I see noticeable floating point error in this test comparing 1-process and 2-process edge calculations, especially when run with float32.
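
For illustration, a dtype-aware comparison along these lines (the helper and its tolerances are hypothetical, not the test's actual code) is the kind of check involved:

import torch

def report_edge_diff(edges_a: torch.Tensor, edges_b: torch.Tensor) -> None:
    # Report absolute error, then check against dtype-dependent tolerances.
    rtol, atol = (1e-4, 1e-6) if edges_a.dtype == torch.float32 else (1e-7, 1e-10)
    diff = (edges_a - edges_b).abs()
    print(f"mean abs diff {diff.mean().item():.3e}, max abs diff {diff.max().item():.3e}")
    torch.testing.assert_close(edges_a, edges_b, rtol=rtol, atol=atol)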

nix-apollo commented 10 months ago

Running with different batch sizes also gives (slightly) different outputs, even with float64. Consider the following code:

import subprocess
import tempfile
from pathlib import Path

import pytest
import torch


@pytest.mark.slow
def test_batch_size_calc_gives_same_edges():
    rib_dir = str(Path(__file__).parent.parent)

    def make_config(name: str, temp_dir: str, bs: int):
        config_str = f"""
        exp_name: {name}
        seed: 0
        tlens_pretrained: null
        tlens_model_path: {rib_dir}/experiments/train_modular_arithmetic/sample_checkpoints/lr-0.001_bs-10000_norm-None_2023-09-27_18-19-33/model_epoch_60000.pt
        node_layers:
            - ln1.0
            - mlp_in.0
            - unembed
            - output
        dataset:
            source: custom
            name: modular_arithmetic
            return_set: train
        batch_size: {bs}
        edge_batch_size: 1024
        truncation_threshold: 1e-6
        rotate_final_node_layer: false
        last_pos_module_type: add_resid1
        n_intervals: 0
        dtype: float64
        eval_type: accuracy
        out_dir: {temp_dir}
        """
        config_path = f"{temp_dir}/{name}.yaml"
        with open(config_path, "w") as f:
            f.write(config_str)
        return config_path

    run_file = rib_dir + "/experiments/lm_rib_build/run_lm_rib_build.py"

    with tempfile.TemporaryDirectory() as temp_dir:
        single_config_path = make_config("test_single", temp_dir, 512)
        double_config_path = make_config("test_double", temp_dir, 128)
        subprocess.run(["python", run_file, single_config_path], capture_output=True)
        print("done with single!")
        subprocess.run(["python", run_file, double_config_path], capture_output=True)

        print("done with double!")

        single_edges = torch.load(f"{temp_dir}/test_single_rib_graph.pt")["edges"]
        double_edges = torch.load(f"{temp_dir}/test_double_rib_graph.pt")["edges"]

        for (module, s_edges), (_, d_edges) in zip(single_edges, double_edges):
            assert torch.allclose(s_edges, d_edges, atol=1e-10), (module, s_edges, d_edges)

try:
    test_batch_size_calc_gives_same_edges()
except AssertionError as e:
    mod, a, b = e.args[0]
    print("Assert fail on", mod)

    print("mean abs diff", (a-b).abs().mean().item())
    print("max abs diff", (a-b).abs().max().item())

Which outputs:

done with single!
done with double!
Assert fail on ln1.0
mean abs diff 1.197769426592479e-07
max abs diff 6.682997782339766e-05
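
For scale (my own back-of-the-envelope, not from the original thread): the observed difference is enormous relative to float64 rounding of a single operation, which points at accumulated error (e.g. different summation orders across batch sizes) rather than final-step rounding.

import torch

eps64 = torch.finfo(torch.float64).eps      # ~2.22e-16
observed_max_diff = 6.682997782339766e-05   # from the output above
print(observed_max_diff / eps64)            # ~3.0e+11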

I'm running this on #196, which is forked from #189.

danbraunai-apollo commented 10 months ago

This branch contains code used for debugging differences in attention layers between n_ctx=2048 and a smaller n_ctx. A discussion thread is here.

The solution was actually just that we need IGNORE to be smaller than -1e5 for pythia-14m. There was a weird bug, though: changing that number in the "register_buffer" call doesn't actually change the value that gets used.
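
One plausible mechanism for that register_buffer oddity (a guess, not verified against the actual model code): buffers are part of the module's state_dict, so a value loaded from a checkpoint silently overrides whatever default was passed to register_buffer at construction.

import torch
from torch import nn

class Attn(nn.Module):
    def __init__(self, ignore_value: float = -1e5):
        super().__init__()
        # Buffers are included in the state_dict, so checkpoints carry their values.
        self.register_buffer("IGNORE", torch.tensor(ignore_value))

old = Attn(ignore_value=-1e5)
new = Attn(ignore_value=-1e9)           # change the value passed at construction...
new.load_state_dict(old.state_dict())   # ...but loading a checkpoint restores -1e5
print(new.IGNORE)                       # tensor(-100000.)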

danbraunai-apollo commented 9 months ago

Much of this will hopefully be fixed by #222. But I will leave this thread up and reduce the priority to low.

danbraunai-apollo commented 9 months ago

Right now, we're just defaulting to using float64 everywhere except for toy models. We will investigate precision errors later when scaling up.

Note that it is possible to use NumPy's float128 on a CPU to compare with float64, but it would be a fair bit of work to change all of our torch operations.
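
For what it's worth, a minimal sketch of that kind of CPU reference check (np.longdouble is what "float128" resolves to, and is typically 80-bit extended precision on x86-64 rather than true quad precision):

import numpy as np
import torch

# Hypothetical reference check: redo a computation in extended precision on CPU
# and measure how far the float64 torch result drifts from it.
a = torch.randn(512, 512, dtype=torch.float64)
b = torch.randn(512, 512, dtype=torch.float64)

torch_result = (a @ b).numpy()
ref = a.numpy().astype(np.longdouble) @ b.numpy().astype(np.longdouble)

print("max abs diff vs longdouble:", np.abs(torch_result - ref).max())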