saprmarks / feature-circuits


Potential bugs and confusion with `attribution.jvp` #10

Open JacksonKaunismaa opened 1 month ago

JacksonKaunismaa commented 1 month ago

I have been working through the paper and examining the code for computing edge weights, and I believe I have discovered some unexpected behavior, as well as some other areas of confusion. I would greatly appreciate some clarification on where I'm going wrong here.

Unexpected zeros

We wish to compute the edge weights between two residual layers, as in
https://github.com/saprmarks/feature-circuits/blob/c1a9b7b3a400363d3e83b4319b989a4949959fe7/circuit.py#L157-L208

This is done by computing the direct effect of the residual at layer n on the residual at layer n+1 (line 199) and then subtracting off the indirect contributions of the residual at layer n to the residual at layer n+1 via both the MLP and the attention submodules (line 204).
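As a toy, self-contained illustration of that subtraction (an assumed `resid + mlp(resid)` structure with attention omitted; this is not the repo's model or code), subtracting the MLP-mediated part of the residual-to-residual Jacobian from the total Jacobian should leave only the direct path:

```python
import torch

# Toy model: resid_{n+1} = resid_n + mlp(resid_n). Assumed structure for illustration only.
W = torch.randn(4, 4)
mlp = lambda x: torch.tanh(W @ x)
r_n = torch.randn(4)

total_jac = torch.autograd.functional.jacobian(lambda x: x + mlp(x), r_n)  # analogous to RR (total effect)
via_mlp_jac = torch.autograd.functional.jacobian(mlp, r_n)                 # analogous to RMR (MLP-mediated part)
direct_jac = total_jac - via_mlp_jac                                       # the direct residual path
print(torch.allclose(direct_jac, torch.eye(4), atol=1e-6))                 # True: identity in this toy model
```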

However, these indirect contributions (`RMR_effect` and `RAR_effect`) appear to always be all zeros in my testing with all of the provided supervised datasets in `data`, regardless of edge/node thresholds, batch size, example length, or input. Is this behavior expected?

downstream_feat indexing

It also seems to me that in this section of `attribution.jvp`, `downstream_feat` is being used in two unrelated ways:
https://github.com/saprmarks/feature-circuits/blob/c1a9b7b3a400363d3e83b4319b989a4949959fe7/attribution.py#L346-L357

Using the notation of equation 6 from the paper, on line 350 it indexes into `left_vec`, selecting the gradient of some particular feature in $\mathbf{d}$ with respect to the features in the intermediate node $\mathbf{m}$. Then, on line 357, we index into `to_backprop` with `downstream_feat`, where `to_backprop` is an element-wise product of an intermediate-node gradient over $\mathbf{m}$ and the current activation of $\mathbf{m}$. If `downstream_feat` corresponds to some feature in node $\mathbf{d}$, why can we use it to index into the intermediate node $\mathbf{m}$?
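To make the pattern concrete, here is a schematic paraphrase of the two uses (names and shapes are made up for illustration; this is not the actual `attribution.py` code):

```python
import torch

# Made-up sizes: 5 features in downstream node d, 7 features in intermediate node m.
n_down, n_mid = 5, 7
left_vec = {f: torch.randn(n_mid) for f in range(n_down)}   # per-d-feature gradient vectors over m
mid_grad, mid_act = torch.randn(n_mid), torch.randn(n_mid)
to_backprop = mid_grad * mid_act                             # element-wise product over the m dimension

for downstream_feat in range(n_down):
    v = left_vec[downstream_feat]      # use 1: select the gradient vector for a d-feature
    s = to_backprop[downstream_feat]   # use 2: the same d index applied along the m dimension (the confusing part)
```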

vjv vs. jv

Finally, I believe there could be some issue with how `vjv` and `jv` are computed inside of `attribution.jvp`. As far as I can tell (and running the code confirms this), by the return statement here, https://github.com/saprmarks/feature-circuits/blob/c1a9b7b3a400363d3e83b4319b989a4949959fe7/attribution.py#L388-L391, `vjv_indices` and `vjv_values` are identical to `jv_indices` and `jv_values`. Therefore, the only way that `MR_effect` and `MR_grad` differ in https://github.com/saprmarks/feature-circuits/blob/c1a9b7b3a400363d3e83b4319b989a4949959fe7/circuit.py#L162 is their size; each has the same underlying values and indices. However, on lines 186 and 196 in https://github.com/saprmarks/feature-circuits/blob/c1a9b7b3a400363d3e83b4319b989a4949959fe7/circuit.py#L179-L197, this difference in size matters when we reshape the `*_grad` variables. If we had instead applied the same reshape to the analogous `*_effect` variables, we would end up with two different tensors that have the same values but rearranged in a strange permutation that does not seem to correspond to how these variables were computed in the first place. Is this behavior expected?
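To illustrate the reshape point with a toy example (made-up sizes, not the repo's actual tensors): two sparse tensors that share the same flat indices and values but declare different sizes land their nonzeros at different coordinates once reshaped.

```python
import torch

# Same flat indices and values, different declared sizes (analogous to *_effect vs. *_grad).
indices = torch.tensor([[3, 10, 17]])
values = torch.tensor([1.0, 2.0, 3.0])
a = torch.sparse_coo_tensor(indices, values, size=(20,))
b = torch.sparse_coo_tensor(indices, values, size=(24,))

# Reshaping interprets the same flat index against a different width,
# so identical entries end up at different 2-D coordinates.
print(a.to_dense().reshape(4, 5))  # nonzeros at (0, 3), (2, 0), (3, 2)
print(b.to_dense().reshape(4, 6))  # nonzeros at (0, 3), (1, 4), (2, 5)
```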

saprmarks commented 1 month ago

Many, many thanks—you're absolutely right about your "Unexpected zeros" and "vjv vs. jv" points: it looks like the computation of the correction terms for the inter-residual-stream effects was just completely bugged.

I've attempted a bugfix in the new jvp-fix branch. I've tested that it runs and doesn't have the issues you raised above (e.g. there are nonzero values for RMR and RAR). On brief inspection, it also looks like the new implementation coincides with the results of the old implementation for AR, MR, RA, and MA (and I'm relatively confident that the old implementation gave correct results for these). But better testing is evidently needed, which I'm hoping to get to in the next few days.

One note about the use of `jvp` for computing RAR and RMR:

```python
RMR_effect = jvp(
    input=clean,
    model=model,
    dictionaries=dictionaries,
    downstream_submod=mlp,                            # backprop starts from the MLP activations...
    downstream_features=features_by_submod[resid],    # ...but the features we iterate over live in the next residual stream
    upstream_submod=prev_resid,
    left_vec={feat_idx: unflatten(MR_grad[feat_idx].to_dense()) for feat_idx in features_by_submod[resid]},
    right_vec=deltas[prev_resid],
)
```

(Argument names added for clarity.) One weird, but intentional, thing that's happening here is that the `downstream_submod` is not the same submodule that the `downstream_features` are coming from. What's going on here is that we need to compute, for each $\mathbf{d}$ and $\mathbf{u}$, the RHS of the expression

$$\sum_{\mathbf{m}} \nabla_{\mathbf{d}} m \, \nabla_{\mathbf{m}} \mathbf{d} \, \nabla_{\mathbf{u}} \mathbf{m} = \nabla_{\mathbf{u}}\left(\sum_{\mathbf{m}} \nabla_{\mathbf{d}} m \, \nabla_{\mathbf{m}} \mathbf{d} \; \mathbf{m}\right)$$

(where $m$ is the metric, as in the paper), treating the gradients inside the parentheses on the RHS as constants. This requires iterating over downstream features $\mathbf{d}$ (here, from the next layer's residual stream) but backpropping from (a weighted sum of) $\mathbf{m}$ (here, MLP or attention activations).
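A minimal sanity check of that identity in isolation (toy shapes, plain PyTorch autograd rather than the repo's `jvp` machinery): contracting constant weights against the Jacobian $\partial\mathbf{m}/\partial\mathbf{u}$ gives the same result as backpropping from the weighted sum of $\mathbf{m}$.

```python
import torch

# Toy check of: sum_m w_m * (dm/du) == d/du( sum_m w_m * m ), with w held constant.
u = torch.randn(4, requires_grad=True)
W = torch.randn(3, 4)                    # stand-in for the upstream -> intermediate map
f = lambda x: torch.tanh(W @ x)          # intermediate activations m(u)
w = torch.randn(3)                       # constant weights, playing the role of grad_d m * grad_m d

(rhs,) = torch.autograd.grad((w * f(u)).sum(), u)    # backprop from the weighted sum of m
lhs = w @ torch.autograd.functional.jacobian(f, u)   # contract the full Jacobian with the same weights
print(torch.allclose(lhs, rhs, atol=1e-6))           # True
```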

In other words, when reading the implementation of `jvp` in this case, you should keep in mind that `downstream_feat` doesn't correspond to a feature of the `downstream_submod`; it's actually a feature of a deeper submodule.

(Of course, the old implementation still wasn't correct; for instance, it didn't sum over the $\mathbf{m}$ dimension!)