Open frankaging opened 8 months ago
memo comments from Noah offline: if the counterfactual dataset comes with golden labels for the intervened outputs through time, then this is also a case where gradients through time are not necessary. Moving this to P2 as it is not prioritized for now.
Descriptions: Currently, we support only a limited intervention use case on stateful models such as GRUs. For instance, after an intervention, although the causal effect of the intervened site ripples through time, we assume the inputs remain the same as before the intervention. This is fine if the task setup doesn't care about the inputs, if generation is input-agnostic, or if teacher forcing (forcing the model to discard its own outputs) is allowed during generation.
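The fixed-input case can be sketched with a toy GRU (hypothetical shapes and helper names, not the library's API): we swap a hidden state from a second example into the first example's run, then keep unrolling with the first example's original inputs.

```python
import torch

torch.manual_seed(0)

# Hypothetical minimal setup: a GRUCell unrolled over 6 steps, two examples.
emb_dim, hid_dim, seq_len = 8, 16, 6
cell = torch.nn.GRUCell(emb_dim, hid_dim)

x1 = torch.randn(seq_len, 1, emb_dim)  # inputs of example 1 (kept fixed)
x2 = torch.randn(seq_len, 1, emb_dim)  # inputs of example 2

def run(inputs, h0=None, start=0):
    """Unroll the GRU from step `start`, returning the hidden states."""
    h = torch.zeros(1, hid_dim) if h0 is None else h0
    hs = []
    for t in range(start, seq_len):
        h = cell(inputs[t], h)
        hs.append(h)
    return hs

hs1 = run(x1)  # h1..h6 for example 1
hs2 = run(x2)  # h1'..h6' for example 2

# Intervention: replace h2 (hs1[1]) with h3' (hs2[2]), then recompute
# h3..h6 using example 1's ORIGINAL inputs x3..x6 — the inputs are not
# updated, only the hidden state ripples through time.
post = run(x1, h0=hs2[2], start=2)  # updated h3..h6
```

Note that `post[-1]` (the intervened `h6`) generally differs from `hs1[-1]`, even though the inputs after the intervention site are unchanged.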
Here are some illustrations. Right now, we can support cross-time interventions where we take `h3'` from the second example to intervene in `h2` from the first example through time. We then also update `h3` to `h6` after the intervention. However, we assume `x3` to `x6` still use the inputs from example 1. This is acceptable if, during training, `x3` to `x6` are agnostic with respect to the model's generation (e.g., `x2` is some trigger token so the model is in generation mode).

However, this is not ideal. Ideally, if we are dealing with autoregressive LMs, we want `x3` to be the intervened model's output at the previous step. This requires the model to pass gradients through time. One simple solution is to update the model to use Gumbel-Softmax to softly select the token and pass it to the next time step as the input embedding. The change may be limited to the modeling side: we need to change the model to do soft token selection, which allows gradients to flow. This is compatible with the library, since this input-based unrolling only makes sense in intervention mode.
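A minimal sketch of the Gumbel-Softmax idea (toy module names and sizes are assumptions, not the library's API): instead of a hard, non-differentiable argmax over the vocabulary, we draw a soft one-hot with `torch.nn.functional.gumbel_softmax` and feed the probability-weighted mix of token embeddings to the next step, so gradients pass through the "sampled" token.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny autoregressive model, just to show the mechanism.
vocab, emb_dim, hid_dim = 10, 8, 16
embed = torch.nn.Embedding(vocab, emb_dim)
cell = torch.nn.GRUCell(emb_dim, hid_dim)
lm_head = torch.nn.Linear(hid_dim, vocab)

h = torch.zeros(1, hid_dim)
x = embed(torch.tensor([3]))          # (1, emb_dim), some start token

# Step t: softly select the next token instead of taking a hard argmax.
h = cell(x, h)
logits = lm_head(h)                   # (1, vocab)
probs = F.gumbel_softmax(logits, tau=1.0, hard=False)  # differentiable soft one-hot
x_next = probs @ embed.weight         # (1, emb_dim), mix of token embeddings

# Step t+1 consumes the soft token, so gradients flow through time.
h = cell(x_next, h)
loss = lm_head(h).sum()
loss.backward()                       # reaches embed/cell parameters at step t
```

With `hard=True`, `gumbel_softmax` returns a discrete one-hot in the forward pass while keeping the soft gradient (straight-through), which may be preferable if the model should see an actual token embedding at inference time.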