ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License
2 stars 0 forks source link

Graphs not invariant under `rotate_final_node_layer` #241

Open stefan-apollo opened 10 months ago

stefan-apollo commented 10 months ago

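According to the math, an orthogonal rotation (like `rotate_final_node_layer`) should not affect the RIB bases in any of the other layers. However, it does affect the edges, specifically the truncation (the number of basis directions kept) for the cases in the failing tests listed below: MNIST in float32 and modular arithmetic in float64.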

Curiously, it works at the other float precisions, at least at the specific truncation thresholds set in the tests. We should test more thresholds.

I'm not sure how much of this can be explained by floating-point errors, but they may well be the cause. If so, floating-point errors affect even fp64 to some extent.

Associated code in fix/rotation_invariance

FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-alpha)^2-functional-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-0)*alpha-functional-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-alpha)^2-squared-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-0)*alpha-squared-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-alpha)^2-functional-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 478])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-0)*alpha-functional-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 481])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-alpha)^2-squared-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 478])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-0)*alpha-squared-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 481])

def rotate_final_layer_invariance(
    config_str_rotated: str,
    config_cls: Union["LMRibConfig", "MlpRibConfig"],
    build_graph_main_fn: Callable,
    rtol: float = 1e-7,
    atol: float = 0,
):
    """Assert that the non-final edges do not depend on rotating the final node layer.

    The graph is built twice: once from ``config_str_rotated`` and once from a copy in which
    ``rotate_final_node_layer: true`` is replaced by ``rotate_final_node_layer: false``. The
    edges of every node layer except the last two are then compared for shape and values.

    Args:
        config_str_rotated: YAML config containing the literal text
            ``rotate_final_node_layer: true``.
        config_cls: Config class instantiated with the parsed YAML as keyword arguments.
        build_graph_main_fn: Callable taking a config and returning a mapping with an
            ``"edges"`` entry, indexable as ``edges[i][1]`` to obtain the edge tensor.
        rtol: Relative tolerance passed to ``torch.allclose``.
        atol: Absolute tolerance passed to ``torch.allclose``.

    Raises:
        AssertionError: If the config does not contain the rotate flag, or if the edges of a
            compared layer differ in shape or in value beyond the tolerances.
    """
    rotated_flag = "rotate_final_node_layer: true"
    not_rotated_flag = "rotate_final_node_layer: false"
    # Without this guard the replace below is a no-op, both runs are identical and the
    # comparison passes vacuously.
    assert (
        rotated_flag in config_str_rotated
    ), f"config_str_rotated must contain '{rotated_flag}' so the not-rotated config can be derived"
    config_str_not_rotated = config_str_rotated.replace(rotated_flag, not_rotated_flag)

    config_rotated = config_cls(**yaml.safe_load(config_str_rotated))
    config_not_rotated = config_cls(**yaml.safe_load(config_str_not_rotated))

    edges_rotated = build_graph_main_fn(config_rotated)["edges"]
    edges_not_rotated = build_graph_main_fn(config_not_rotated)["edges"]

    # -1 has no edges, -2 is the final layer and changes
    print("Node layers", config_rotated.node_layers)
    comparison_layers = config_rotated.node_layers[:-2]
    for i, module_name in enumerate(comparison_layers):
        print("Comparing", module_name)
        # Each entry of the edges list is a tuple (name, tensor)
        not_rotated = edges_not_rotated[i][1]
        rotated = edges_rotated[i][1]
        # Check shape first: a shape mismatch means the truncation differed between the runs
        assert not_rotated.shape == rotated.shape, (
            f"edges_not_rotated and edges_rotated not same shape for {module_name}, "
            f"got shapes {not_rotated.shape} and {rotated.shape}"
        )
        # Check values. The message is only evaluated when the assertion fails.
        assert torch.allclose(not_rotated, rotated, rtol=rtol, atol=atol), (
            f"edges_not_rotated not equal to edges_rotated for {module_name}. "
            f"Min and max of edges_not_rotated / edges_rotated: "
            f"{(not_rotated / rotated).min()}, {(not_rotated / rotated).max()}"
        )


@pytest.mark.slow
@pytest.mark.parametrize(
    "basis_formula, edge_formula, dtype_str",
    [
        ("(1-alpha)^2", "functional", "float32"),
        ("(1-0)*alpha", "functional", "float32"),
        ("(1-alpha)^2", "functional", "float64"),
        ("(1-0)*alpha", "functional", "float64"),
        ("(1-alpha)^2", "squared", "float32"),
        ("(1-0)*alpha", "squared", "float32"),
        ("(1-alpha)^2", "squared", "float64"),
        ("(1-0)*alpha", "squared", "float64"),
    ],
)
def test_mnist_rotate_final_layer_invariance(
    basis_formula: str, edge_formula: str, dtype_str: str, rtol=1e-7, atol=1e-8
):
    """Test that the non-final edges are the same for MNIST whether or not we rotate the final layer.

    The config is built for every parametrized combination of basis formula, edge formula and
    dtype, and handed to ``rotate_final_layer_invariance`` together with the tolerances.

    Args:
        basis_formula: Value substituted for ``basis_formula`` in the YAML config.
        edge_formula: Value substituted for ``edge_formula`` in the YAML config.
        dtype_str: Value substituted for ``dtype`` in the YAML config.
        rtol: Relative tolerance forwarded to ``rotate_final_layer_invariance``.
        atol: Absolute tolerance forwarded to ``rotate_final_layer_invariance``.
    """
    # NOTE(review): rtol=1e-7 is below float32 machine epsilon (~1.19e-7); confirm this
    # tolerance is intended for the float32 parametrizations.
    # The config sets `dtype` exactly once, from the parametrized `dtype_str`. A second,
    # hard-coded `dtype: float64` entry used to precede it; as a duplicate YAML key it was
    # shadowed by the later entry and only obscured which dtype was in effect.
    config_str_rotated = f"""
    exp_name: test
    mlp_path: experiments/train_mlp/sample_checkpoints/lr-0.001_bs-64_2023-11-29_14-36-29/model_epoch_12.pt
    batch_size: 256
    seed: 0
    truncation_threshold: 1e-6
    rotate_final_node_layer: true  # Gets overridden by rotate_final_layer_invariance
    n_intervals: 0
    dataset:
        return_set_frac: 0.01  # 3 batches (with batch_size=256)
    node_layers:
    - layers.1
    - layers.2
    - output
    out_dir: null
    dtype: {dtype_str}
    basis_formula: "{basis_formula}"
    edge_formula: "{edge_formula}"
    """

    rotate_final_layer_invariance(
        config_str_rotated=config_str_rotated,
        config_cls=MlpRibConfig,
        build_graph_main_fn=mlp_build_graph_main,
        rtol=rtol,
        atol=atol,
    )

@pytest.mark.slow
@pytest.mark.parametrize(
    "basis_formula, edge_formula, dtype_str",
    [
        # functional fp32 currently fails with these tolerances
        ("(1-alpha)^2", "functional", "float32"),
        ("(1-0)*alpha", "functional", "float32"),
        ("(1-alpha)^2", "functional", "float64"),
        ("(1-0)*alpha", "functional", "float64"),
        ("(1-alpha)^2", "squared", "float32"),
        ("(1-0)*alpha", "squared", "float32"),
        ("(1-alpha)^2", "squared", "float64"),
        ("(1-0)*alpha", "squared", "float64"),
    ],
)
def test_modular_arithmetic_rotate_final_layer_invariance(
    basis_formula: str,
    edge_formula: str,
    dtype_str: str,
    rtol=1e-3,
    atol=1e-3,
):
    """Check that modadd edges before the final layer do not depend on its rotation.

    A YAML config for the modular arithmetic model is filled in with the parametrized basis
    formula, edge formula and dtype, then passed to ``rotate_final_layer_invariance``.

    A non-zero atol is required because the less important edges do deviate between the two
    runs; the largest edges are between 1e3 and 1e5 in size.
    """
    rotated_config_yaml = f"""
    exp_name: test
    seed: 0
    tlens_pretrained: null
    tlens_model_path: experiments/train_modular_arithmetic/sample_checkpoints/lr-0.001_bs-10000_norm-None_2023-11-28_16-07-19/model_epoch_60000.pt
    dataset:
        source: custom
        name: modular_arithmetic
        return_set: train
        return_set_frac: null
        return_set_n_samples: 10
    node_layers:
        - mlp_out.0
        - unembed
        - output
    batch_size: 6
    gram_batch_size: 6
    edge_batch_size: 6
    truncation_threshold: 1e-15
    rotate_final_node_layer: true  # Gets overridden by rotate_final_layer_invariance
    last_pos_module_type: add_resid1
    n_intervals: 2
    dtype: {dtype_str}
    eval_type: accuracy
    out_dir: null
    basis_formula: "{basis_formula}"
    edge_formula: "{edge_formula}"
    """
    # The helper derives the not-rotated config from this string and compares both graphs.
    rotate_final_layer_invariance(
        config_str_rotated=rotated_config_yaml,
        config_cls=LMRibConfig,
        build_graph_main_fn=lm_build_graph_main,
        rtol=rtol,
        atol=atol,
    )