Use weights as a general way to specify the weight of each integration step. The current approach is clunky and limited to the trapezoidal rule; do this when we implement a different integration method. I think it should look roughly like the version shown here, though this version currently fails some tests. Not urgent; do later.
New function
def _calc_integration_intervals(
n_intervals: int,
integral_boundary_relative_epsilon: float = 1e-3,
) -> tuple[np.ndarray, np.ndarray]:
"""Calculate the integration steps for n_intervals between 0+eps and 1-eps.
Args:
n_intervals: The number of intervals to use for the integral approximation. If 0, take a
point estimate at alpha=0.5 instead of using the trapezoidal rule.
integral_boundary_relative_epsilon: Rather than integrating from 0 to 1, we integrate from
integral_boundary_epsilon to 1 - integral_boundary_epsilon, to avoid issues with
ill-defined derivatives at 0 and 1.
integral_boundary_epsilon = integral_boundary_relative_epsilon/(n_intervals+1).
Returns:
alphas: The integration steps.
weights: The weights for each integration step, accounting for integral size,
integral_boundary_epsilon, and the trapezoidal rule (0.5 for the endpoints).
"""
# Scale accuracy of the integral boundaries with the number of intervals
integral_boundary_epsilon = integral_boundary_relative_epsilon / (n_intervals + 1)
# Integration samples
if n_intervals == 0:
alphas = np.array([0.5])
weights = np.array([1])
n_alphas = 1
else:
# Integration steps for n_intervals intervals
n_alphas = n_intervals + 1
alphas = np.linspace(integral_boundary_epsilon, 1 - integral_boundary_epsilon, n_alphas)
assert np.allclose(np.diff(alphas), alphas[1] - alphas[0]), "alphas must be equally spaced."
# Multiply the interval sizes by (1 + 2 eps) to balance out the smaller integration interval
interval_size = (alphas[1] - alphas[0]) / (1 - 2 * integral_boundary_epsilon)
weights = np.ones_like(alphas) * interval_size
# As per the trapezoidal rule, multiply the endpoints by 1/2
weights[0] *= 0.5
weights[-1] *= 0.5
assert np.allclose(
n_intervals * interval_size,
1,
), f"n_intervals * interval_size ({n_intervals * interval_size}) != 1"
assert np.allclose(weights.sum(), 1), f"weights.sum() ({weights.sum()}) != 1"
return alphas, weights
Diff file
diff --git a/rib/linalg.py b/rib/linalg.py
index 5bf4379..71d0691 100644
--- a/rib/linalg.py
+++ b/rib/linalg.py
@@ -142,7 +142,7 @@ def pinv_diag(x: Float[Tensor, "a a"]) -> Float[Tensor, "a a"]:
def _calc_integration_intervals(
n_intervals: int,
integral_boundary_relative_epsilon: float = 1e-3,
-) -> tuple[np.ndarray, float]:
+) -> tuple[np.ndarray, np.ndarray]:
"""Calculate the integration steps for n_intervals between 0+eps and 1-eps.
Args:
@@ -155,15 +155,15 @@ def _calc_integration_intervals(
Returns:
alphas: The integration steps.
- interval_size: The size of each integration step, including a correction factor to account
- for integral_boundary_epsilon.
+ weights: The weights for each integration step, accounting for integral size,
+ integral_boundary_epsilon, and the trapezoidal rule (0.5 for the endpoints).
"""
# Scale accuracy of the integral boundaries with the number of intervals
integral_boundary_epsilon = integral_boundary_relative_epsilon / (n_intervals + 1)
# Integration samples
if n_intervals == 0:
alphas = np.array([0.5])
- interval_size = 1.0
+ weights = np.array([1.0])
n_alphas = 1
else:
# Integration steps for n_intervals intervals
@@ -172,11 +172,16 @@ def _calc_integration_intervals(
assert np.allclose(np.diff(alphas), alphas[1] - alphas[0]), "alphas must be equally spaced."
# Multiply the interval sizes by (1 + 2 eps) to balance out the smaller integration interval
interval_size = (alphas[1] - alphas[0]) / (1 - 2 * integral_boundary_epsilon)
+ weights = np.ones_like(alphas) * interval_size
+ # As per the trapezoidal rule, multiply the endpoints by 1/2
+ weights[0] *= 0.5
+ weights[-1] *= 0.5
assert np.allclose(
n_intervals * interval_size,
1,
), f"n_intervals * interval_size ({n_intervals * interval_size}) != 1"
- return alphas, interval_size
+ assert np.allclose(weights.sum(), 1), f"weights.sum() ({weights.sum()}) != 1"
+ return alphas, weights
def integrated_gradient_trapezoidal_jacobian(
@@ -203,20 +208,17 @@ def integrated_gradient_trapezoidal_jacobian(
f_in_hat.requires_grad_(True)
# Prepare integral
- alphas, interval_size = _calc_integration_intervals(
+ alphas, weights = _calc_integration_intervals(
n_intervals, integral_boundary_relative_epsilon=1e-3
)
if edge_formula == "october":
f_out_hat_const = module_hat(f_in_hat)
- for alpha_index, alpha in tqdm(
- enumerate(alphas), total=len(alphas), desc="Integration steps (alphas)", leave=False
+ for alpha_index, (alpha, weight) in tqdm(
+ enumerate(zip(alphas, weights)),
+ total=len(alphas),
+ desc="Integration steps (alphas)",
+ leave=False,
):
- # As per the trapezoidal rule, multiply the endpoints by 1/2
- # (unless we're taking a point estimate at alpha=0.5)
- scaler = (
- 0.5 if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals) else 1
- )
-
einsum_pattern = "bpj,bpj->j" if f_in_hat.ndim == 3 else "bj,bj->j"
# Normalize by the dataset size and the number of positions (if the input has a position dim)
normalization_factor = f_in_hat.shape[1] * dataset_size if has_pos else dataset_size
@@ -247,8 +249,7 @@ def integrated_gradient_trapezoidal_jacobian(
i_grad = (
torch.autograd.grad(f_out_hat_norm[i], alpha_f_in_hat, retain_graph=True)[0]
/ normalization_factor
- * interval_size
- * scaler
+ * weight
)
with torch.inference_mode():
E = torch.einsum(einsum_pattern, i_grad, f_in_hat)
@@ -285,25 +286,16 @@ def integrated_gradient_trapezoidal_jacobian(
else torch.zeros(batch_size, out_hidden_size_comb_trunc, in_hidden_size_comb_trunc)
)
# Integral
- for alpha_index, alpha in tqdm(
- enumerate(alphas), total=len(alphas), desc="Integration steps (alphas)", leave=False
+ for alpha_index, (alpha, weight) in tqdm(
+ enumerate(zip(alphas, weights)),
+ total=len(alphas),
+ desc="Integration steps (alphas)",
+ leave=False,
):
- # As per the trapezoidal rule, multiply the endpoints by 1/2
- # (unless we're taking a point estimate at alpha=0.5)
- scaler = (
- 0.5 if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals) else 1
- )
- #
# We have to compute inputs from f_hat to make autograd work
alpha_f_in_hat = alpha * f_in_hat
f_out_alpha_hat = module_hat(alpha_f_in_hat)
- # As per the trapezoidal rule, multiply the endpoints by 1/2 (unless we're taking a point
- # estimate at alpha=1)
- scaler = (
- 0.5 if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals) else 1
- )
-
# Take the derivative of the (i, t) element (output dim and output pos) of the output
# Note that t (output pos) is different from p (tprime, input pos)
for out_dim in range(out_hidden_size_comb_trunc):
@@ -322,9 +314,7 @@ def integrated_gradient_trapezoidal_jacobian(
# Sum over tprime (p, input pos) as per Lucius' formula (A.18)
with torch.inference_mode():
- inner_token_sum = torch.einsum(
- einsum_pattern, grad_0 * interval_size * scaler, f_in_hat
- )
+ inner_token_sum = torch.einsum(einsum_pattern, grad_0 * weight, f_in_hat)
# We have a minus sign in front of the IG integral, see e.g. the definition of g_j
# in equation (3.27)
inner_token_sums[:, token_index, out_dim, :] -= inner_token_sum.to(
@@ -395,11 +385,9 @@ def integrated_gradient_trapezoidal_norm(
in_grads = torch.zeros_like(torch.cat(inputs, dim=-1))
- alphas, interval_size = _calc_integration_intervals(
- n_intervals, integral_boundary_relative_epsilon
- )
+ alphas, weights = _calc_integration_intervals(n_intervals, integral_boundary_relative_epsilon)
- for alpha_index, alpha in enumerate(alphas):
+ for alpha_index, (alpha, weight) in enumerate(zip(alphas, weights)):
# Compute f^{l+1}(f^l(alpha x))
alpha_inputs = tuple(alpha * x for x in inputs)
output_alpha = module(*alpha_inputs)
@@ -443,7 +431,5 @@ def integrated_gradient_trapezoidal_norm(
- if n_intervals > 0 and (alpha_index == 0 or alpha_index == n_intervals):
- alpha_in_grads = 0.5 * alpha_in_grads
- in_grads += alpha_in_grads * interval_size
+ in_grads += alpha_in_grads * weight
for x in alpha_inputs:
assert x.grad is not None, "Input grad should not be None."
Use weights as a general way to specify the weight of each integration step. The current approach is clunky and limited to the trapezoidal rule; do this when we implement a different integration method. I think it should look roughly like the version shown here, though this version currently fails some tests. Not urgent; do later.
New function
Diff file