Proposal
Allow ActivationCache.get_full_resid_decomposition to receive a project_output_onto tensor that is either a [d_model] tensor or a [d_model, num_outputs] tensor, such that we multiply the output by it. Internally, rather than computing (neurons W_out) and projecting afterwards, compute neurons (W_out @ project_output_onto); this is much more memory efficient.
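A minimal sketch of the intended factorization (standalone PyTorch/einops with made-up shapes; illustrative only, not the actual get_full_resid_decomposition internals):

```python
import torch
import einops

# Illustrative shapes only
batch, pos, d_mlp, d_model, num_outputs = 4, 64, 2048, 512, 2

neuron_acts = torch.randn(batch, pos, d_mlp)             # post-activation neuron values
W_out = torch.randn(d_mlp, d_model)                      # MLP output weights
project_output_onto = torch.randn(d_model, num_outputs)  # directions to project onto

# Current approach: materialise every neuron's residual-stream term, then project.
# The intermediate is [batch, pos, d_mlp, d_model], which is what blows up memory.
per_neuron_resid = einops.einsum(
    neuron_acts, W_out,
    "batch pos d_mlp, d_mlp d_model -> batch pos d_mlp d_model",
)
naive = per_neuron_resid @ project_output_onto           # [batch, pos, d_mlp, num_outputs]

# Proposed approach: fold the projection into W_out first, so the largest
# tensor ever created is [batch, pos, d_mlp, num_outputs].
projected_W_out = W_out @ project_output_onto            # [d_mlp, num_outputs]
efficient = einops.einsum(
    neuron_acts, projected_W_out,
    "batch pos d_mlp, d_mlp out -> batch pos d_mlp out",
)

assert torch.allclose(naive, efficient, rtol=1e-4, atol=1e-4)
```

A [d_model] project_output_onto is the same thing with num_outputs collapsed: W_out @ project_output_onto becomes a [d_mlp] vector, and the per-neuron result is just neuron_acts times that vector.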
Motivation
There are a ton of neurons, and creating a [d_mlp, d_model] tensor at every position and batch can blow out your GPU memory fast. This means that if we just want, e.g., each neuron's contribution to the output logit of the correct next token, we can feed in that single direction and save memory.
This is a bit messy, since there are many ways we might want to do this (e.g. having a different output vector per position for each correct next token), but this seems like a good MVP.
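For concreteness, a hypothetical call once this exists (project_output_onto is the proposed new argument, and model, cache, and answer_token are assumed to be set up already; W_U and expand_neurons are the existing TransformerLens unembedding and flag):

```python
# Hypothetical usage -- project_output_onto does not exist yet.
logit_dir = model.W_U[:, answer_token]  # [d_model] direction for the correct token's logit

neuron_logit_contribs = cache.get_full_resid_decomposition(
    expand_neurons=True,            # one component per neuron, not per MLP layer
    project_output_onto=logit_dir,  # proposed argument: project before expanding
)
```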