ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License
2 stars 0 forks source link

Replace "scaler" and "interval_size" with general weights #226

Open stefan-apollo opened 10 months ago

stefan-apollo commented 10 months ago

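Use `weights` as a general way to specify the weight of each integration step. The current approach (a separate `scaler` and `interval_size`) is clunky and limited to the trapezoidal rule, so do this when we implement a different integration scheme. I think it should look roughly like the code shown here, though this version fails some tests. A likely cause: in the diff, `integrated_gradient_trapezoidal_norm` still halves `alpha_in_grads` at the endpoints even though `weights` already includes the trapezoidal 0.5 factor, so the endpoint contributions are halved twice there. Not urgent; do later.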

New function

def _calc_integration_intervals(
    n_intervals: int,
    integral_boundary_relative_epsilon: float = 1e-3,
) -> tuple[np.ndarray, np.ndarray]:
    """Calculate the integration steps for n_intervals between 0+eps and 1-eps.

    Args:
        n_intervals: The number of intervals to use for the integral approximation. If 0, take a
            point estimate at alpha=0.5 instead of using the trapezoidal rule.
        integral_boundary_relative_epsilon: Rather than integrating from 0 to 1, we integrate from
            integral_boundary_epsilon to 1 - integral_boundary_epsilon, to avoid issues with
            ill-defined derivatives at 0 and 1.
            integral_boundary_epsilon = integral_boundary_relative_epsilon/(n_intervals+1).

    Returns:
        alphas: The integration steps.
        weights: The weights for each integration step, accounting for integral size,
            integral_boundary_epsilon, and the trapezoidal rule (0.5 for the endpoints).
    """
    # Scale accuracy of the integral boundaries with the number of intervals
    integral_boundary_epsilon = integral_boundary_relative_epsilon / (n_intervals + 1)
    # Integration samples
    if n_intervals == 0:
        alphas = np.array([0.5])
        weights = np.array([1])
        n_alphas = 1
    else:
        # Integration steps for n_intervals intervals
        n_alphas = n_intervals + 1
        alphas = np.linspace(integral_boundary_epsilon, 1 - integral_boundary_epsilon, n_alphas)
        assert np.allclose(np.diff(alphas), alphas[1] - alphas[0]), "alphas must be equally spaced."
        # Multiply the interval sizes by (1 + 2 eps) to balance out the smaller integration interval
        interval_size = (alphas[1] - alphas[0]) / (1 - 2 * integral_boundary_epsilon)
        weights = np.ones_like(alphas) * interval_size
        # As per the trapezoidal rule, multiply the endpoints by 1/2
        weights[0] *= 0.5
        weights[-1] *= 0.5
        assert np.allclose(
            n_intervals * interval_size,
            1,
        ), f"n_intervals * interval_size ({n_intervals * interval_size}) != 1"
        assert np.allclose(weights.sum(), 1), f"weights.sum() ({weights.sum()}) != 1"
    return alphas, weights

Diff file

diff --git a/rib/linalg.py b/rib/linalg.py
index 5bf4379..71d0691 100644
--- a/rib/linalg.py
+++ b/rib/linalg.py
@@ -142,7 +142,7 @@ def pinv_diag(x: Float[Tensor, "a a"]) -> Float[Tensor, "a a"]:
 def _calc_integration_intervals(
     n_intervals: int,
     integral_boundary_relative_epsilon: float = 1e-3,
-) -> tuple[np.ndarray, float]:
+) -> tuple[np.ndarray, np.ndarray]:
     """Calculate the integration steps for n_intervals between 0+eps and 1-eps.

     Args:
@@ -155,15 +155,15 @@ def _calc_integration_intervals(

     Returns:
         alphas: The integration steps.
-        interval_size: The size of each integration step, including a correction factor to account
-            for integral_boundary_epsilon.
+        weights: The weights for each integration step, accounting for integral size,
+            integral_boundary_epsilon, and the trapezoidal rule (0.5 for the endpoints).
     """
     # Scale accuracy of the integral boundaries with the number of intervals
     integral_boundary_epsilon = integral_boundary_relative_epsilon / (n_intervals + 1)
     # Integration samples
     if n_intervals == 0:
         alphas = np.array([0.5])
-        interval_size = 1.0
+        weights = np.array([1])
         n_alphas = 1
     else:
         # Integration steps for n_intervals intervals
@@ -172,11 +172,16 @@ def _calc_integration_intervals(
         assert np.allclose(np.diff(alphas), alphas[1] - alphas[0]), "alphas must be equally spaced."
         # Multiply the interval sizes by (1 + 2 eps) to balance out the smaller integration interval
         interval_size = (alphas[1] - alphas[0]) / (1 - 2 * integral_boundary_epsilon)
+        weights = np.ones_like(alphas) * interval_size
+        # As per the trapezoidal rule, multiply the endpoints by 1/2
+        weights[0] *= 0.5
+        weights[-1] *= 0.5
         assert np.allclose(
             n_intervals * interval_size,
             1,
         ), f"n_intervals * interval_size ({n_intervals * interval_size}) != 1"
-    return alphas, interval_size
+        assert np.allclose(weights.sum(), 1), f"weights.sum() ({weights.sum()}) != 1"
+    return alphas, weights

 def integrated_gradient_trapezoidal_jacobian(
@@ -203,20 +208,17 @@ def integrated_gradient_trapezoidal_jacobian(
     f_in_hat.requires_grad_(True)

     # Prepare integral
-    alphas, interval_size = _calc_integration_intervals(
+    alphas, weights = _calc_integration_intervals(
         n_intervals, integral_boundary_relative_epsilon=1e-3
     )
     if edge_formula == "october":
         f_out_hat_const = module_hat(f_in_hat)
-        for alpha_index, alpha in tqdm(
-            enumerate(alphas), total=len(alphas), desc="Integration steps (alphas)", leave=False
+        for alpha_index, (alpha, weight) in tqdm(
+            enumerate(zip(alphas, weights)),
+            total=len(alphas),
+            desc="Integration steps (alphas)",
+            leave=False,
         ):
-            # As per the trapezoidal rule, multiply the endpoints by 1/2
-            # (unless we're taking a point estimate at alpha=0.5)
-            scaler = (
-                0.5 if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals) else 1
-            )
-
             einsum_pattern = "bpj,bpj->j" if f_in_hat.ndim == 3 else "bj,bj->j"
             # Normalize by the dataset size and the number of positions (if the input has a position dim)
             normalization_factor = f_in_hat.shape[1] * dataset_size if has_pos else dataset_size
@@ -247,8 +249,7 @@ def integrated_gradient_trapezoidal_jacobian(
                 i_grad = (
                     torch.autograd.grad(f_out_hat_norm[i], alpha_f_in_hat, retain_graph=True)[0]
                     / normalization_factor
-                    * interval_size
-                    * scaler
+                    * weight
                 )
                 with torch.inference_mode():
                     E = torch.einsum(einsum_pattern, i_grad, f_in_hat)
@@ -285,25 +286,16 @@ def integrated_gradient_trapezoidal_jacobian(
             else torch.zeros(batch_size, out_hidden_size_comb_trunc, in_hidden_size_comb_trunc)
         )
         # Integral
-        for alpha_index, alpha in tqdm(
-            enumerate(alphas), total=len(alphas), desc="Integration steps (alphas)", leave=False
+        for alpha_index, (alpha, weight) in tqdm(
+            enumerate(zip(alphas, weights)),
+            total=len(alphas),
+            desc="Integration steps (alphas)",
+            leave=False,
         ):
-            # As per the trapezoidal rule, multiply the endpoints by 1/2
-            # (unless we're taking a point estimate at alpha=0.5)
-            scaler = (
-                0.5 if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals) else 1
-            )
-            #
             # We have to compute inputs from f_hat to make autograd work
             alpha_f_in_hat = alpha * f_in_hat
             f_out_alpha_hat = module_hat(alpha_f_in_hat)

-            # As per the trapezoidal rule, multiply the endpoints by 1/2 (unless we're taking a point
-            # estimate at alpha=1)
-            scaler = (
-                0.5 if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals) else 1
-            )
-
             # Take the derivative of the (i, t) element (output dim and output pos) of the output
             # Note that t (output pos) is different from p (tprime, input pos)
             for out_dim in range(out_hidden_size_comb_trunc):
@@ -322,9 +314,7 @@ def integrated_gradient_trapezoidal_jacobian(

                     # Sum over tprime (p, input pos) as per Lucius' formula (A.18)
                     with torch.inference_mode():
-                        inner_token_sum = torch.einsum(
-                            einsum_pattern, grad_0 * interval_size * scaler, f_in_hat
-                        )
+                        inner_token_sum = torch.einsum(einsum_pattern, grad_0 * weight, f_in_hat)
                         # We have a minus sign in front of the IG integral, see e.g. the definition of g_j
                         # in equation (3.27)
                         inner_token_sums[:, token_index, out_dim, :] -= inner_token_sum.to(
@@ -395,11 +385,9 @@ def integrated_gradient_trapezoidal_norm(

     in_grads = torch.zeros_like(torch.cat(inputs, dim=-1))

-    alphas, interval_size = _calc_integration_intervals(
-        n_intervals, integral_boundary_relative_epsilon
-    )
+    alphas, weights = _calc_integration_intervals(n_intervals, integral_boundary_relative_epsilon)

-    for alpha_index, alpha in enumerate(alphas):
+    for alpha_index, (alpha, weight) in enumerate(zip(alphas, weights)):
         # Compute f^{l+1}(f^l(alpha x))
         alpha_inputs = tuple(alpha * x for x in inputs)
         output_alpha = module(*alpha_inputs)
@@ -443,7 +431,7 @@ def integrated_gradient_trapezoidal_norm(
         if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals):
             alpha_in_grads = 0.5 * alpha_in_grads

-        in_grads += alpha_in_grads * interval_size
+        in_grads += alpha_in_grads * weight

         for x in alpha_inputs:
             assert x.grad is not None, "Input grad should not be None."