Consider an LLM P() that gives a probability distribution over token sequences. For a given imposed token sequence x_0, we are interested in the reachable set of outputs y = \arg\max P(y | u + x_0) where u is a control input text sequence of k tokens or fewer.
The greedy forward reachability algorithm takes an initial state string x_0 and applies something like our greedy back generation algorithm to explore the reachable set in an open-ended manner.
For reference, greedy back generation tries to find a prompt u that steers the model to produce y_star from initial state x_0 by:
# goal: get the LLM to produce y_star from init state x_0
# (i.e., make y_star the argmax given the full context u + x_0)
u = ''
while True:
    for token in vocabulary:
        logits = P(y | token + u + x_0)
        token_score[token] = ce_loss(logits, y_star)
    best_token = argmin_token(token_score)
    u = best_token + u
    # check whether y_star is now the argmax of P(y | u + x_0);
    # if so, return the optimal prompt u.
    if argmax(P(y | u + x_0)) == y_star:
        return u
We are now interested in open-ended reachability analysis: a set of prompts U is optimized so that the reachable set of outputs y = \arg\max P(y | u + x_0) for u in U is as large as possible. We denote {y | y = \arg\max P(y | u + x_0), u in U_t} =: R_t, where U_t is the set of prompts at step t, R_t is the corresponding reachable set of outputs, and t indexes steps of the optimization/exploration algorithm.
t=1: Exhaustively iterate through all single-token prompts u and get logits = P(y | u + x_0). Determine whether any u leads to a new argmax not already in the reachable set R_t. If so, add u to U_t and update R_t.
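The t=1 sweep can be sketched as follows; the toy logit function and the names `initial_sweep`, `U`, `R` are illustrative assumptions, not from the notes:

```python
def initial_sweep(x0, vocab, next_token_logits):
    """t = 1: try every single-token prompt u, record each new argmax."""
    U, R = [], set()
    for u in vocab:
        logits = next_token_logits(u + x0)
        y = vocab[logits.index(max(logits))]  # argmax output token
        if y not in R:                        # a newly reachable output
            R.add(y)
            U.append(u)
    return U, R

# toy stand-in model: logit of token t = count of t in the context
toy = lambda ctx: [float(ctx.count(t)) for t in ["a", "b"]]
U1, R1 = initial_sweep("ab", ["a", "b"], toy)
```

With the toy model, each single-token prompt tips the argmax toward itself, so both outputs become reachable at t=1.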
Compute loss = -ce_loss(answer_logits, unif(R_t)) + entropy(answer_logits[~R_t]). We want to drive up the CE between the answer logits and a uniform target over the already-reached set R_t, and we want to make the distribution more peaked (low entropy) over the unreached tokens ~R_t.
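A minimal sketch of this loss, assuming `entropy(answer_logits[~R_t])` means the entropy of the softmax renormalized over the unreached tokens (that reading of the notation is an assumption):

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def exploration_loss(answer_logits, reached):
    """loss = -ce_loss(answer_logits, unif(R_t)) + entropy(answer_logits[~R_t]).

    'reached' is the set of token indices already in R_t.
    """
    p = softmax(answer_logits)
    # CE between the logits and a uniform target over the reached tokens
    ce = -sum(math.log(p[i]) for i in reached) / len(reached)
    # entropy of the softmax restricted to the unreached tokens
    rest = [l for i, l in enumerate(answer_logits) if i not in reached]
    q = softmax(rest)
    ent = -sum(qi * math.log(qi) for qi in q if qi > 0.0)
    return -ce + ent
```

Lower loss favors prompts whose logits push probability mass off R_t and concentrate it on a single unreached token, which is exactly the exploration pressure described above.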
Keep some portion of the lowest-loss prompts (even if they didn't lead to new argmaxes).
Sample one of the lowest-loss prompts and try adding all possible single-token back extensions.
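These two steps together can be sketched as a prune-and-extend routine; `frontier` as a list of (prompt, loss) pairs and the function name are hypothetical choices, not from the notes:

```python
import random

def prune_and_extend(frontier, vocab, keep=4):
    """Keep the lowest-loss prompts, sample one, and propose all
    single-token back extensions (prepends) of it.

    'frontier' is a list of (u, loss) pairs.
    """
    kept = sorted(frontier, key=lambda pair: pair[1])[:keep]  # lowest loss first
    u, _ = random.choice(kept)              # sample a surviving prompt
    proposals = [tok + u for tok in vocab]  # every single-token prepend of u
    return kept, proposals
```

Each proposal would then be scored with the exploration loss, and any proposal whose argmax is new joins U_t.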
It would probably be good to store this as a tree structure; some dicts would be easiest. We start with a list of top-level single-token extension dictionaries, where 'children' is a list of dictionaries, each with the same format. It's worth tracking the 'last update' time of the loss, since the loss changes as R_t expands.
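One possible node format for this tree, sketched in Python; the field names are illustrative, not fixed by the notes:

```python
# Hypothetical node format for the back-extension tree.
def make_node(token):
    return {
        "token": token,     # the single token this node prepends
        "loss": None,       # most recent loss for the prompt this node encodes
        "last_update": -1,  # step t at which 'loss' was computed; the loss
                            # goes stale as R_t expands, so recompute when old
        "children": [],     # further back-extensions, same node format
    }

# top level: one node per single-token extension of the empty prompt
tree = [make_node(tok) for tok in ["a", "b", "c"]]
```

Reading the prompt off a node means concatenating tokens along the path from that node back to the root, matching the prepend order of the algorithm.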
The loss here is computed as follows: let answer_logits be the vocabulary-sized vector of logits from P(y | u + x_0). These logits are better if they (1) reduce the probability of tokens in R_t and (2) increase the spikiness (reduce the entropy) of the distribution over tokens not in R_t. Therefore our loss is -ce_loss(answer_logits, unif(R_t)) + entropy(answer_logits[~R_t]).