amanb2000 / Magic_Words

Code for the paper "What's the Magic Word? A Control Theory of LLM Prompting"
MIT License

Greedy Forward Reachability #7

Closed: amanb2000 closed this issue 3 months ago

amanb2000 commented 3 months ago

Consider an LLM P(·) that gives a probability distribution over token sequences. For a given imposed token sequence x_0, we are interested in the reachable set of outputs y = \arg\max_y P(y | u + x_0), where u + x_0 denotes the control input u followed by x_0, and u is a control input sequence of k tokens or fewer.
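
To make the definition concrete, the reachable set can be computed exactly for a toy model by brute force: enumerate every prompt u of at most k tokens and collect the argmax next token for each. This is a minimal sketch that treats y as a single next token (as the vocabulary-sized answer_logits later in this issue suggest); `next_token_logits`, `reachable_set`, and `toy_logits` are hypothetical names, not part of the Magic_Words code. The cost grows as |V|^k model calls, which is why a greedy search is needed for a real LLM.

```python
import itertools

def argmax(xs):
    # index of the largest logit (ties go to the lowest index)
    return max(range(len(xs)), key=xs.__getitem__)

def reachable_set(next_token_logits, x_0, vocab_size, k):
    """Exhaustively compute {argmax_y P(y | u + x_0) : u has at most k tokens}.

    next_token_logits(ids) is a hypothetical stand-in for the LLM: it maps a
    list of token ids to a list of next-token logits. The loop makes
    sum_{n=0..k} vocab_size**n model calls, so it is only feasible for toys.
    """
    reached = set()
    for n in range(k + 1):  # n = 0 is the empty prompt u = ''
        for u in itertools.product(range(vocab_size), repeat=n):
            reached.add(argmax(next_token_logits(list(u) + x_0)))
    return reached

def toy_logits(ids, vocab_size=4):
    # hypothetical toy "LLM" whose next token is sum(ids) mod vocab_size
    peak = sum(ids) % vocab_size
    return [1.0 if j == peak else 0.0 for j in range(vocab_size)]

print(reachable_set(toy_logits, [1], 4, k=0))  # {1}
print(reachable_set(toy_logits, [1], 4, k=1))  # {0, 1, 2, 3}
```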

This greedy forward reachability algorithm takes an initial state string x_0 and applies a variant of our greedy back-generation algorithm to explore the reachable set in an open-ended manner.

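For reference, the greedy back-generation algorithm tries to find a prompt u that steers the model to produce a target y_star from the initial state x_0 by repeatedly prepending the single token that minimizes the cross-entropy loss on y_star: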

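import math

def ce_loss(logits, target):
    # cross-entropy of the target token under softmax(logits)
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]

def greedy_back_generation(P, x_0, y_star, vocab_size, max_len=10):
    # goal: get LLM to produce y_star from init state u + x_0; P(ids) -> next-token logits
    u = []
    while len(u) < max_len:
        # score every single-token back extension of the current prompt u
        scores = [ce_loss(P([token] + u + x_0), y_star) for token in range(vocab_size)]
        best_token = min(range(vocab_size), key=scores.__getitem__)
        u = [best_token] + u
        # check if y_star is now the argmax; if reached, return the optimal prompt u
        logits = P(u + x_0)
        if max(range(vocab_size), key=logits.__getitem__) == y_star:
            return u
    return None  # y_star not reached within max_len tokens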

We are now interested in open-ended reachability analysis: a set of prompts U is optimized so that the reachable set of outputs y = \arg\max_y P(y | u + x_0) for u in U is as large as possible. We denote R_t := {y | y = \arg\max_y P(y | u + x_0), u in U_t}, where t is the time step of the optimization/exploration algorithm, U_t is the set of prompts at time t, and R_t is the reachable set (the set of outputs reached so far) at time t.

  1. t=1: Exhaustively iterate through all single-token prompts u and compute logits = P(y | u + x_0). If any u leads to a new argmax not already in the reachable set R_t, add it to U and update R_t.
  2. For each candidate prompt, compute loss = -ce_loss(answer_logits, unif(R_t)) + entropy(answer_logits[~R_t]). We want to drive up the CE loss between the answer distribution and the uniform distribution over the already-reached R_t, and we want to make the distribution over the unreached tokens ~R_t more peaked (lower entropy).
  3. Keep some portion of the lowest-loss prompts (even if they didn't lead to new argmaxes).
  4. Sample one of the lowest-loss prompts and try all possible single-token back extensions of it (prepending each vocabulary token), checking each for new argmaxes as in step 1.
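
The four steps above can be sketched as a single loop. This is a minimal, self-contained sketch under stated assumptions, not the Magic_Words implementation: `next_token_logits` stands in for the LLM, while `reach_loss`, `explore`, and `toy_logits` are hypothetical names. `keep` is the number of lowest-loss prompts retained in step 3, prompts are tuples of token ids, and for simplicity every kept prompt is rescored on each iteration instead of tracking stale losses.

```python
import math
import random

def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)

def reach_loss(logits, reached):
    # -ce_loss(logits, unif(R_t)) + entropy(logits[~R_t]); lower is better
    logp = log_softmax(logits)
    neg_ce = sum(logp[r] for r in reached) / len(reached) if reached else 0.0
    rest = [z for j, z in enumerate(logits) if j not in reached]
    entropy = -sum(math.exp(lp) * lp for lp in log_softmax(rest)) if rest else 0.0
    return neg_ce + entropy

def explore(next_token_logits, x_0, vocab_size, steps=20, keep=4, seed=0):
    """Sketch of open-ended greedy forward reachability (steps 1-4).

    next_token_logits(ids) is a hypothetical LLM wrapper returning a list of
    next-token logits. Returns (U, R): U maps each reached output y to one
    prompt u that reaches it, and R is the final reachable set R_t.
    """
    rng = random.Random(seed)
    U, R, pool = {}, set(), []
    candidates = [(tok,) for tok in range(vocab_size)]  # step 1: all single-token u
    for _ in range(steps):
        # record any candidate whose argmax is not already in R_t
        for u in candidates:
            y = argmax(next_token_logits(list(u) + x_0))
            if y not in R:
                R.add(y)
                U[y] = u
        if len(R) == vocab_size:
            break  # every output token has been reached
        # steps 2-3: rescore kept + new prompts against the current R_t, keep the best
        pool = sorted(set(pool) | set(candidates),
                      key=lambda u: reach_loss(next_token_logits(list(u) + x_0), R))[:keep]
        # step 4: sample a kept prompt and try every single-token back extension
        parent = rng.choice(pool)
        candidates = [(tok,) + parent for tok in range(vocab_size)]
    return U, R

def toy_logits(ids, V=6):
    # hypothetical toy model: the peak depends on length and the first token's parity
    peak = min(len(ids) - 1 + ids[0] % 2, V - 1)
    return [1.0 if j == peak else 0.0 for j in range(V)]

U, R = explore(toy_logits, x_0=[0], vocab_size=6)
print(sorted(R))
```

With this toy model, every single-token prompt has argmax 1 or 2, so step 1 alone gives R_t = {1, 2}; longer prompts can reach 3, 4, and 5, while 0 is unreachable. How many of those a given run finds depends on the sampling in step 4.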

It would probably be good to store the explored prompts as a tree; nested dicts are probably easiest. We start with a list of top-level dictionaries, one per single-token prompt:

[
{'prompt_ids': [0], 'prompt_text': "!", 'children': [{...}, ...], 'loss': 3.2, 'last_update': 2},
{'prompt_ids': [1], 'prompt_text': "?", 'children': [{...}, ...], 'loss': 4.2, 'last_update': 4},
...
]

where 'children' is a list of dictionaries with the same format. It's useful to record when each loss was last computed ('last_update'), since a node's loss changes as R_t expands and stored losses go stale.
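
One way to work with this tree is sketched below, using the same dict keys as the example above. It is a hypothetical sketch: `make_node`, `iter_nodes`, `refresh_stale`, `best_nodes`, `decode`, and the placeholder `loss_fn` are illustrative names, and it assumes the caller knows the time `changed_at` at which R_t last grew, so only losses computed before that time are recomputed.

```python
def make_node(prompt_ids, decode, loss, t):
    """One tree node in the dict format above; decode(ids) -> str is a
    hypothetical detokenizer (e.g. a tokenizer's decode method)."""
    return {'prompt_ids': list(prompt_ids), 'prompt_text': decode(prompt_ids),
            'children': [], 'loss': loss, 'last_update': t}

def iter_nodes(nodes):
    # depth-first walk over a list of top-level nodes and all their descendants
    stack = list(nodes)
    while stack:
        node = stack.pop()
        yield node
        stack.extend(node['children'])

def refresh_stale(nodes, changed_at, t, loss_fn):
    """Recompute every loss computed before R_t last changed (at time changed_at)."""
    refreshed = 0
    for node in iter_nodes(nodes):
        if node['last_update'] < changed_at:
            node['loss'] = loss_fn(node['prompt_ids'])
            node['last_update'] = t
            refreshed += 1
    return refreshed

def best_nodes(nodes, n):
    # the n lowest-loss prompts anywhere in the tree, e.g. for sampling in step 4
    return sorted(iter_nodes(nodes), key=lambda node: node['loss'])[:n]

def decode(ids):
    # hypothetical two-token vocabulary {0: '!', 1: '?'}
    return ''.join('!?'[i] for i in ids)

tree = [make_node([0], decode, 3.2, t=2), make_node([1], decode, 4.2, t=4)]
tree[0]['children'].append(make_node([1, 0], decode, 2.5, t=4))
n_refreshed = refresh_stale(tree, changed_at=3, t=5, loss_fn=lambda ids: float(len(ids)))
print(n_refreshed)                                        # 1
print([nd['prompt_ids'] for nd in best_nodes(tree, 2)])   # [[0], [1, 0]]
```

Only the node scored at time 2 is refreshed here, because R_t last changed at time 3; nodes scored at time 4 are still current. This lazy refresh avoids rescoring the whole tree every time R_t grows.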

The loss is computed as follows. Let answer_logits be the vocabulary-sized vector of logits from P(y | u + x_0). These logits are better if they (1) reduce the probability of tokens in R_t and (2) increase the spikiness (reduce the entropy) of the distribution over tokens not in R_t. The loss is therefore -ce_loss(answer_logits, unif(R_t)) + entropy(answer_logits[~R_t]).
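
Written out with p = softmax(answer_logits) over the full vocabulary, and reading entropy(answer_logits[~R_t]) as the entropy of the softmax taken over only the unreached logits (an interpretation, since the issue does not spell this out), the loss for a prompt u is:

```latex
\begin{aligned}
\mathcal{L}(u) &= -\,\mathrm{CE}\big(\mathrm{unif}(R_t),\, p\big) + H(\tilde p) \\
  &= \frac{1}{|R_t|} \sum_{r \in R_t} \log p(r) \;-\; \sum_{j \notin R_t} \tilde p(j)\,\log \tilde p(j),
\end{aligned}
\qquad
\tilde p(j) = \frac{p(j)}{\sum_{i \notin R_t} p(i)} = \mathrm{softmax}\big(\text{answer\_logits}[\sim R_t]\big)_j .
```

The first term decreases as probability moves off the already-reached tokens, and the second decreases as the remaining mass concentrates on a single unreached token, so informally a low-loss prompt is one that is close to producing a new argmax.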