stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
http://pyvene.ai
Apache License 2.0

[P2] Intervening through time recurrently without teacher-forcing #35

Open frankaging opened 8 months ago

frankaging commented 8 months ago

Description: Currently, we only support a limited intervention use case on stateful models such as GRUs. After an intervention, the causal effect of the intervened site ripples through time, but we assume the inputs at later steps remain the same as before the intervention. This is fine if the task setup doesn't care about the inputs, is input-agnostic when generating, or allows teacher forcing (i.e., the model's own outputs are discarded) when generating.

Here are some illustrations. Right now, we can support cross-time interventions as follows:

example 1:
(hiddens)   h1, h2, h3, h4, h5, h6
(inputs)    x1, x2, x3, x4, x5, x6
                ^
                |
                +------+
                       |
example 2:
(hiddens)   h1', h2', h3', h4', h5', h6'
(inputs)    x1', x2', x3', x4', x5', x6'

where we take h3' from the second example and use it to intervene on h2 from the first example through time. We then recompute h3 through h6 after the intervention. However, we assume x3 through x6 still use the inputs from example 1. This is acceptable if, during training, x3 through x6 do not matter for the model's generation (e.g., x2 is some trigger token after which the model is in generation mode).
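For concreteness, here is a minimal, library-agnostic sketch of what this currently amounts to, written against a plain `torch.nn.GRUCell` (the names below mirror the diagram and are not pyvene API):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

emb = nn.Embedding(100, 16)   # toy vocabulary of 100 tokens, embedding dim 16
cell = nn.GRUCell(16, 16)     # toy GRU cell with hidden size 16

def run(tokens, swap_step=None, swap_hidden=None):
    """Unroll the GRU over `tokens`, optionally overwriting one hidden state."""
    h = torch.zeros(1, 16)
    hiddens = []
    for t, tok in enumerate(tokens):
        h = cell(emb(tok).unsqueeze(0), h)
        if t == swap_step:
            h = swap_hidden          # the intervention: replace h_t with a donor state
        hiddens.append(h)
    return hiddens

x_a = torch.randint(0, 100, (6,))    # example 1 inputs x1..x6
x_b = torch.randint(0, 100, (6,))    # example 2 inputs x1'..x6'

donor = run(x_b)[2]                  # h3' from example 2 (step index 2)
# Intervene on h2 of example 1; h3..h6 are recomputed with the new state,
# but x3..x6 still come from example 1 -- no input-based unrolling.
patched = run(x_a, swap_step=1, swap_hidden=donor)
```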

However, this is not ideal. For autoregressive LMs, we ideally want x3 to be the intervened model's output at the previous step, which requires the model to pass gradients through time. One simple solution is to have the model use Gumbel-Softmax to softly select the next token and pass it to the next time step as the input embedding.

The change may be confined to the modeling side: the model needs to perform soft token selection so that gradients can flow. This is compatible with the library, since input-based unrolling only makes sense in intervention mode anyway.
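A rough sketch of what the proposed soft unrolling step could look like, assuming we can read out the model's next-token logits and its input embedding matrix at each step (`soft_unroll_step` is a hypothetical helper, not existing pyvene or model API):

```python
import torch
import torch.nn.functional as F

def soft_unroll_step(logits, embedding_weight, tau=1.0):
    """Differentiably turn this step's logits into the next step's input embedding.

    logits:           (batch, vocab) next-token logits after the intervention
    embedding_weight: (vocab, dim) the model's input embedding matrix
    Returns a (batch, dim) soft embedding for the next time step, keeping the
    computation graph intact so gradients can flow through time.
    """
    # Gumbel-Softmax yields a (nearly) one-hot distribution over the vocabulary
    # while staying differentiable; hard=True uses the straight-through estimator.
    soft_one_hot = F.gumbel_softmax(logits, tau=tau, hard=True)
    return soft_one_hot @ embedding_weight
```

During an intervened forward pass, the generation loop would call something like this instead of argmax-plus-re-embed; for a Hugging Face causal LM the embedding matrix is typically `model.get_input_embeddings().weight`.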

frankaging commented 8 months ago

Memo from offline comments by Noah: if the counterfactual dataset comes with gold labels for the intervened outputs through time, then gradients through time are also not necessary. Moving this to P2 since it is not prioritized for now.