According to the math, an orthogonal rotation of the final node layer (like rotate_final_node_layer) should not affect the RIB bases in any of the other layers. However, it does affect the edges, specifically how many directions survive truncation, for
MNIST at float32
Mod Add at float64
Curiously, the tests pass at the other dtypes (MNIST at float64, Mod Add at float32), at least at the specific truncation thresholds set in the tests. We should test more thresholds.
Not sure how much of this can be explained by floating point errors, but they may well be the cause. If so, fp errors affect even fp64 to some extent.
Associated code is in fix/rotation_invariance
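To illustrate why the number of kept directions could flip between the two runs, here is a minimal sketch. It assumes truncation is a hard cut that keeps values strictly above the threshold, which may not match the actual RIB code. The spectra, the helper names n_kept and thresholds_with_mismatch, and the 1e-7 relative perturbation (roughly float32 rounding size) are all hypothetical. It also shows what sweeping more thresholds might look like.

```python
def n_kept(eigenvalues, threshold):
    """Count the values that survive a hard truncation threshold (assumed rule: value > threshold)."""
    return sum(1 for ev in eigenvalues if ev > threshold)


def thresholds_with_mismatch(evs_a, evs_b, thresholds):
    """Return the thresholds at which the two spectra keep a different number of directions."""
    return [t for t in thresholds if n_kept(evs_a, t) != n_kept(evs_b, t)]


# Hypothetical spectra: the second is the first with a tiny relative perturbation,
# standing in for the numerical difference between the rotated and non-rotated runs.
evs_not_rotated = [1.0, 1e-3, 1.00000005e-6, 1e-9]
evs_rotated = [ev * (1 - 1e-7) for ev in evs_not_rotated]

# Only the threshold that sits right next to a value is affected.
mismatches = thresholds_with_mismatch(evs_not_rotated, evs_rotated, [1e-4, 1e-6, 1e-8])
print(mismatches)  # [1e-06]
```

In this toy case only the threshold sitting right next to a value produces a different count, which would be consistent with the failures depending on the specific thresholds chosen in the tests.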
FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-alpha)^2-functional-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-0)*alpha-functional-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-alpha)^2-squared-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_mnist_rotate_final_layer_invariance[(1-0)*alpha-squared-float32] - AssertionError: edges_not_rotated and edges_rotated not same shape for layers.1, got shapes torch.Size([40, 96]) and torch.Size([48, 96])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-alpha)^2-functional-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 478])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-0)*alpha-functional-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 481])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-alpha)^2-squared-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 478])
FAILED tests/test_build_graph.py::test_modular_arithmetic_rotate_final_layer_invariance[(1-0)*alpha-squared-float64] - AssertionError: edges_not_rotated and edges_rotated not same shape for mlp_out.0, got shapes torch.Size([126, 483]) and torch.Size([126, 481])
from typing import Callable, Union

import pytest
import torch
import yaml

# LMRibConfig, MlpRibConfig, lm_build_graph_main and mlp_build_graph_main come from
# the repository itself (import paths omitted here).


def rotate_final_layer_invariance(
    config_str_rotated: str,
    config_cls: Union["LMRibConfig", "MlpRibConfig"],
    build_graph_main_fn: Callable,
    rtol: float = 1e-7,
    atol: float = 0,
):
    # Build a second config that is identical except for the final-layer rotation
    config_str_not_rotated = config_str_rotated.replace(
        "rotate_final_node_layer: true", "rotate_final_node_layer: false"
    )
    config_rotated = config_cls(**yaml.safe_load(config_str_rotated))
    config_not_rotated = config_cls(**yaml.safe_load(config_str_not_rotated))
    edges_rotated = build_graph_main_fn(config_rotated)["edges"]
    edges_not_rotated = build_graph_main_fn(config_not_rotated)["edges"]

    # -1 has no edges, -2 is the final layer and changes
    print("Node layers", config_rotated.node_layers)
    comparison_layers = config_rotated.node_layers[:-2]
    for i, module_name in enumerate(comparison_layers):
        # Each entry of edges is a tuple (name, tensor)
        print("Comparing", module_name)
        # Check shape
        assert edges_not_rotated[i][1].shape == edges_rotated[i][1].shape, (
            f"edges_not_rotated and edges_rotated not same shape for {module_name}, "
            f"got shapes {edges_not_rotated[i][1].shape} and {edges_rotated[i][1].shape}"
        )
        # Check values
        assert torch.allclose(
            edges_not_rotated[i][1],
            edges_rotated[i][1],
            rtol=rtol,
            atol=atol,
        ), (
            f"edges_not_rotated not equal to edges_rotated for {module_name}. "
            f"Biggest relative deviation: "
            f"{(edges_not_rotated[i][1] / edges_rotated[i][1]).min()}, "
            f"{(edges_not_rotated[i][1] / edges_rotated[i][1]).max()}"
        )
        # except AssertionError as e:
        #     import matplotlib.pyplot as plt
        #     plt.scatter(edges_not_rotated[i][1], edges_rotated[i][1])
        #     plt.axhline(atol, color="red")
        #     plt.axvline(atol, color="red")
        #     # straight line rtol
        #     edge_min = min(edges_not_rotated[i][1].min(), edges_rotated[i][1].min())
        #     edge_max = max(edges_not_rotated[i][1].max(), edges_rotated[i][1].max())
        #     plt.plot([edge_min, edge_max], [edge_min, edge_max], color="red")
        #     # rtol
        #     plt.plot([edge_min, edge_max], [edge_min * rtol, edge_max * rtol], color="green")
        #     plt.plot([edge_min, edge_max], [edge_min / rtol, edge_max / rtol], color="green")
        #     plt.savefig(f"tests/rotate_final_layer_invariance_{module_name}.png")


@pytest.mark.slow
@pytest.mark.parametrize(
    "basis_formula, edge_formula, dtype_str",
    [
        ("(1-alpha)^2", "functional", "float32"),
        ("(1-0)*alpha", "functional", "float32"),
        ("(1-alpha)^2", "functional", "float64"),
        ("(1-0)*alpha", "functional", "float64"),
        ("(1-alpha)^2", "squared", "float32"),
        ("(1-0)*alpha", "squared", "float32"),
        ("(1-alpha)^2", "squared", "float64"),
        ("(1-0)*alpha", "squared", "float64"),
    ],
)
def test_mnist_rotate_final_layer_invariance(
    basis_formula, edge_formula, dtype_str, rtol=1e-7, atol=1e-8
):
    """Test that the non-final edges are the same for MNIST whether or not we rotate the final layer."""
    # Note: the config previously also contained a hard-coded "dtype: float64" key.
    # yaml.safe_load keeps the last duplicate key, so only the parametrized dtype took effect.
    config_str_rotated = f"""
    exp_name: test
    mlp_path: experiments/train_mlp/sample_checkpoints/lr-0.001_bs-64_2023-11-29_14-36-29/model_epoch_12.pt
    batch_size: 256
    seed: 0
    truncation_threshold: 1e-6
    rotate_final_node_layer: true # Gets overridden by rotate_final_layer_invariance
    n_intervals: 0
    dataset:
        return_set_frac: 0.01 # 3 batches (with batch_size=256)
    node_layers:
        - layers.1
        - layers.2
        - output
    out_dir: null
    dtype: {dtype_str} # in float32 the truncation changes between both runs
    basis_formula: "{basis_formula}"
    edge_formula: "{edge_formula}"
    """
    rotate_final_layer_invariance(
        config_str_rotated=config_str_rotated,
        config_cls=MlpRibConfig,
        build_graph_main_fn=mlp_build_graph_main,
        rtol=rtol,
        atol=atol,
    )


@pytest.mark.slow
@pytest.mark.parametrize(
    "basis_formula, edge_formula, dtype_str",
    [
        # functional fp32 currently fails with these tolerances
        ("(1-alpha)^2", "functional", "float32"),
        ("(1-0)*alpha", "functional", "float32"),
        ("(1-alpha)^2", "functional", "float64"),
        ("(1-0)*alpha", "functional", "float64"),
        ("(1-alpha)^2", "squared", "float32"),
        ("(1-0)*alpha", "squared", "float32"),
        ("(1-alpha)^2", "squared", "float64"),
        ("(1-0)*alpha", "squared", "float64"),
    ],
)
def test_modular_arithmetic_rotate_final_layer_invariance(
    basis_formula,
    edge_formula,
    dtype_str,
    rtol=1e-3,
    atol=1e-3,
):
    """Test that the non-final edges are independent of final layer rotation for modadd.

    Note that atol is necessary as the less important edges do deviate. The largest edges
    have magnitudes between 1e3 and 1e5.
    """
    config_str_rotated = f"""
    exp_name: test
    seed: 0
    tlens_pretrained: null
    tlens_model_path: experiments/train_modular_arithmetic/sample_checkpoints/lr-0.001_bs-10000_norm-None_2023-11-28_16-07-19/model_epoch_60000.pt
    dataset:
        source: custom
        name: modular_arithmetic
        return_set: train
        return_set_frac: null
        return_set_n_samples: 10
    node_layers:
        - mlp_out.0
        - unembed
        - output
    batch_size: 6
    gram_batch_size: 6
    edge_batch_size: 6
    truncation_threshold: 1e-15
    rotate_final_node_layer: true # Gets overridden by rotate_final_layer_invariance
    last_pos_module_type: add_resid1
    n_intervals: 2
    dtype: {dtype_str}
    eval_type: accuracy
    out_dir: null
    basis_formula: "{basis_formula}"
    edge_formula: "{edge_formula}"
    """
    rotate_final_layer_invariance(
        config_str_rotated=config_str_rotated,
        config_cls=LMRibConfig,
        build_graph_main_fn=lm_build_graph_main,
        rtol=rtol,
        atol=atol,
    )